From 44d46149e7f12cf4f712332f70fad0a6f78a0299 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Tue, 10 Sep 2024 23:46:02 +0200 Subject: [PATCH 01/16] merge --- RELEASES.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/RELEASES.md b/RELEASES.md index cc18cc91b..277af7847 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -10,6 +10,8 @@ - Improved `ot.plot.plot1D_mat` (PR #649) - Added `nx.det` (PR #649) - `nx.sqrtm` is now broadcastable (takes ..., d, d) inputs (PR #649) +- restructure `ot.unbalanced` module (PR #658) +- add `ot.unbalanced.lbfgsb_unbalanced2` and add flexible reference measure `c` in all unbalanced solvers (PR #658) #### Closed issues - Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648) From 20ea01ca107a47a66190af591dd17c09b28dc092 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Thu, 7 Nov 2024 11:41:33 +0100 Subject: [PATCH 02/16] new dev version --- RELEASES.md | 3 +++ ot/__init__.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index 846c906b0..63d77bf19 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,5 +1,8 @@ # Releases +## 0.9.6dev + + ## 0.9.5 *November 2024* diff --git a/ot/__init__.py b/ot/__init__.py index 2b0425a93..5e21d6a76 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -74,7 +74,7 @@ # utils functions from .utils import dist, unif, tic, toc, toq -__version__ = "0.9.5" +__version__ = "0.9.6dev0" __all__ = [ "emd", From 02009d7d97bc77505d0e49b260ba31bde3b642a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Tue, 12 Nov 2024 00:59:53 +0100 Subject: [PATCH 03/16] first commit partial fgw --- ot/gromov/__init__.py | 4 + ot/gromov/_partial.py | 486 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 490 insertions(+) diff --git a/ot/gromov/__init__.py b/ot/gromov/__init__.py index f552cb914..b7520e099 100644 --- a/ot/gromov/__init__.py +++ b/ot/gromov/__init__.py @@ 
-102,6 +102,8 @@ from ._partial import ( partial_gromov_wasserstein, partial_gromov_wasserstein2, + partial_fused_gromov_wasserstein, + partial_fused_gromov_wasserstein2, solve_partial_gromov_linesearch, entropic_partial_gromov_wasserstein, entropic_partial_gromov_wasserstein2, @@ -173,6 +175,8 @@ "fused_unbalanced_across_spaces_divergence", "partial_gromov_wasserstein", "partial_gromov_wasserstein2", + "partial_fused_gromov_wasserstein", + "partial_fused_gromov_wasserstein2", "solve_partial_gromov_linesearch", "entropic_partial_gromov_wasserstein", "entropic_partial_gromov_wasserstein2", diff --git a/ot/gromov/_partial.py b/ot/gromov/_partial.py index e38eeff1c..da9682e25 100644 --- a/ot/gromov/_partial.py +++ b/ot/gromov/_partial.py @@ -495,6 +495,492 @@ def partial_gromov_wasserstein2( return pgw +def partial_fused_gromov_wasserstein( + M, + C1, + C2, + p=None, + q=None, + m=None, + loss_fun="square_loss", + alpha=0.5, + nb_dummies=1, + G0=None, + thres=1, + numItermax=1e4, + tol=1e-8, + symmetric=None, + warn=True, + log=False, + verbose=False, + **kwargs, +): + r""" + Returns the Partial Fused Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{F_1}, \mathbf{p})` + and :math:`(\mathbf{C_2}, \mathbf{F_2}, \mathbf{q})`, with pairwise + distance matrix :math:`\mathbf{M}` between node feature matrices. + + The function solves the following optimization problem using Conditional Gradient: + + .. math:: + \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. 
\ \mathbf{T} \mathbf{1} &\leq \mathbf{p} + + \mathbf{T}^T \mathbf{1} &\leq \mathbf{q} + + \mathbf{T} &\geq 0 + + \mathbf{1}^T \mathbf{T}^T \mathbf{1} = m &\leq \min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\} + + where : + + - :math:`\mathbf{M}`: metric cost matrix between features across domains + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space. + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space. + - :math:`\mathbf{p}`: Distribution in the source space. + - :math:`\mathbf{q}`: Distribution in the target space. + - `m` is the amount of mass to be transported + - `L`: Loss function to account for the misfit between the similarity matrices. + + The formulation of the problem has been proposed in + :ref:`[29] <references-partial-gromov-wasserstein>` + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + .. note:: All computations in the conditional gradient solver are done with + numpy to limit memory overhead. + .. note:: This function will cast the computed transport plan to the data + type of the provided input :math:`\mathbf{C}_1`. Casting to an integer + tensor might result in a loss of precision. If this behaviour is + unwanted, please make sure to provide a floating point input. + + Parameters + ---------- + M : array-like, shape (ns, nt) + Metric cost matrix between features across domains + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,), optional + Distribution in the source space. + If left to its default value None, uniform distribution is taken. + q : array-like, shape (nt,), optional + Distribution in the target space. + If left to its default value None, uniform distribution is taken.
+ m : float, optional + Amount of mass to be transported + (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) + loss_fun : str, optional + Loss function used for the solver either 'square_loss' or 'kl_loss'. + alpha : float, optional + Trade-off parameter (0 < alpha < 1) + nb_dummies : int, optional + Number of dummy points to add (avoid instabilities in the EMD solver) + G0 : array-like, shape (ns, nt), optional + Initialization of the transportation matrix + thres : float, optional + quantile of the gradient matrix to populate the cost matrix when 0 + (default: 1) + numItermax : int, optional + Max number of iterations + tol : float, optional + tolerance for stopping iterations + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + warn: bool, optional. + Whether to raise a warning when EMD did not converge. + log : bool, optional + return log if True + verbose : bool, optional + Print information along iterations + **kwargs : dict + parameters can be directly passed to the emd solver + + + Returns + ------- + T : array-like, shape (`ns`, `nt`) + Optimal transport matrix between the two spaces. + + log : dict + Convergence information and loss. + + .. _references-partial-gromov-wasserstein: + References + ---------- + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. + + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain + and Courty Nicolas "Optimal Transport for structured data with + application on graphs", International Conference on Machine Learning + (ICML). 2019. 
+ """ + arr = [M, C1, C2] + if p is not None: + arr.append(list_to_array(p)) + else: + p = unif(C1.shape[0], type_as=C1) + if q is not None: + arr.append(list_to_array(q)) + else: + q = unif(C2.shape[0], type_as=C1) + if G0 is not None: + G0_ = G0 + arr.append(G0) + + nx = get_backend(*arr) + p0, q0, M0, C10, C20 = p, q, M, C1, C2 + + p = nx.to_numpy(p0) + q = nx.to_numpy(q0) + M = nx.to_numpy(M0) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + if symmetric is None: + symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose( + C2, C2.T, atol=1e-10 + ) + + if m is None: + m = min(np.sum(p), np.sum(q)) + elif m < 0: + raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") + elif m > min(np.sum(p), np.sum(q)): + raise ValueError( + "Problem infeasible. Parameter m should lower or" + " equal than min(|a|_1, |b|_1)." + ) + + if G0 is None: + G0 = ( + np.outer(p, q) * m / (np.sum(p) * np.sum(q)) + ) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. + + else: + G0 = nx.to_numpy(G0_) + # Check marginals of G0 + assert np.all(G0.sum(1) <= p) + assert np.all(G0.sum(0) <= q) + + q_extended = np.append(q, [(np.sum(p) - m) / nb_dummies] * nb_dummies) + p_extended = np.append(p, [(np.sum(q) - m) / nb_dummies] * nb_dummies) + + # cg for GW is implemented using numpy on CPU + np_ = NumpyBackend() + + fC1, fC2, hC1, hC2 = _transform_matrix(C1, C2, loss_fun, np_) + fC2t = fC2.T + if not symmetric: + fC1t, hC1t, hC2t = fC1.T, hC1.T, hC2.T + + ones_p = np_.ones(p.shape[0], type_as=p) + ones_q = np_.ones(q.shape[0], type_as=q) + + def f(G): + pG = G.sum(1) + qG = G.sum(0) + constC1 = np.outer(np.dot(fC1, pG), ones_q) + constC2 = np.outer(ones_p, np.dot(qG, fC2t)) + return gwloss(constC1 + constC2, hC1, hC2, G, np_) + + if symmetric: + + def df(G): + pG = G.sum(1) + qG = G.sum(0) + constC1 = np.outer(np.dot(fC1, pG), ones_q) + constC2 = np.outer(ones_p, np.dot(qG, fC2t)) + return gwggrad(constC1 + constC2, hC1, hC2, G, np_) + else: + + def df(G): + pG = 
G.sum(1) + qG = G.sum(0) + constC1 = np.outer(np.dot(fC1, pG), ones_q) + constC2 = np.outer(ones_p, np.dot(qG, fC2t)) + constC1t = np.outer(np.dot(fC1t, pG), ones_q) + constC2t = np.outer(ones_p, np.dot(qG, fC2)) + + return 0.5 * ( + gwggrad(constC1 + constC2, hC1, hC2, G, np_) + + gwggrad(constC1t + constC2t, hC1t, hC2t, G, np_) + ) + + def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): + df_Gc = df(deltaG + G) + return solve_partial_gromov_linesearch( + G, + deltaG, + cost_G, + df_G, + df_Gc, + M=(1 - alpha) * M, + reg=alpha, + nx=np_, + **kwargs, + ) + + if not nx.is_floating_point(C10): + warnings.warn( + "Input structure matrix consists of integers. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "structure matrix consists of floating point elements.", + stacklevel=2, + ) + + if log: + res, log = partial_cg( + p, + q, + p_extended, + q_extended, + (1 - alpha) * M, + alpha, + f, + df, + G0, + line_search, + log=True, + numItermax=numItermax, + stopThr=tol, + stopThr2=0.0, + warn=warn, + **kwargs, + ) + log["partial_fgw_dist"] = nx.from_numpy(log["loss"][-1], type_as=C10) + return nx.from_numpy(res, type_as=C10), log + else: + return nx.from_numpy( + partial_cg( + p, + q, + p_extended, + q_extended, + (1 - alpha) * M, + alpha, + f, + df, + G0, + line_search, + log=False, + numItermax=numItermax, + stopThr=tol, + stopThr2=0.0, + **kwargs, + ), + type_as=C10, + ) + + +def partial_fused_gromov_wasserstein2( + M, + C1, + C2, + p=None, + q=None, + m=None, + loss_fun="square_loss", + alpha=0.5, + nb_dummies=1, + G0=None, + thres=1, + numItermax=1e4, + tol=1e-7, + symmetric=None, + warn=False, + log=False, + verbose=False, + **kwargs, +): + r""" + Returns the Partial Fused Gromov-Wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{F_1}, \mathbf{p})` + and :math:`(\mathbf{C_2}, \mathbf{F_2}, \mathbf{q})`, with pairwise + distance matrix 
:math:`\mathbf{M}` between node feature matrices. + + The function solves the following optimization problem using Conditional Gradient: + + .. math:: + \mathbf{PFGW}_{\alpha} = \mathop{\min}_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} &\leq \mathbf{p} + + \mathbf{T}^T \mathbf{1} &\leq \mathbf{q} + + \mathbf{T} &\geq 0 + + \mathbf{1}^T \mathbf{T}^T \mathbf{1} = m &\leq \min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\} + + where : + + - :math:`\mathbf{M}`: metric cost matrix between features across domains + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space. + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space. + - :math:`\mathbf{p}`: Distribution in the source space. + - :math:`\mathbf{q}`: Distribution in the target space. + - `m` is the amount of mass to be transported + - `L`: Loss function to account for the misfit between the similarity matrices. + + + The formulation of the problem has been proposed in + :ref:`[29] <references-partial-gromov-wasserstein2>` + + Note that when using backends, this loss function is differentiable wrt the + matrices (M, C1, C2). + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + .. note:: All computations in the conditional gradient solver are done with + numpy to limit memory overhead. + .. note:: This function will cast the computed transport plan to the data + type of the provided input :math:`\mathbf{C}_1`. Casting to an integer + tensor might result in a loss of precision. If this behaviour is + unwanted, please make sure to provide a floating point input.
+ + Parameters + ---------- + M : array-like, shape (ns, nt) + Metric cost matrix between features across domains + C1 : ndarray, shape (ns, ns) + Metric cost matrix in the source space + C2 : ndarray, shape (nt, nt) + Metric cost matrix in the target space + p : ndarray, shape (ns,) + Distribution in the source space + q : ndarray, shape (nt,) + Distribution in the target space + m : float, optional + Amount of mass to be transported + (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) + loss_fun : str, optional + Loss function used for the solver either 'square_loss' or 'kl_loss'. + alpha : float, optional + Trade-off parameter (0 < alpha < 1) + nb_dummies : int, optional + Number of dummy points to add (avoid instabilities in the EMD solver) + G0 : ndarray, shape (ns, nt), optional + Initialization of the transportation matrix + thres : float, optional + quantile of the gradient matrix to populate the cost matrix when 0 + (default: 1) + numItermax : int, optional + Max number of iterations + tol : float, optional + tolerance for stopping iterations + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + warn: bool, optional. + Whether to raise a warning when EMD did not converge. + log : bool, optional + return log if True + verbose : bool, optional + Print information along iterations + **kwargs : dict + parameters can be directly passed to the emd solver + + + .. warning:: + When dealing with a large number of points, the EMD solver may face + some instabilities, especially when the mass associated to the dummy + point is large. To avoid them, increase the number of dummy points + (allows a smoother repartition of the mass over the points). 
+ + + Returns + ------- + partial_fgw_dist : float + partial FGW discrepancy + log : dict + log dictionary returned only if `log` is `True` + + .. _references-partial-gromov-wasserstein2: + References + ---------- + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. + + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain + and Courty Nicolas "Optimal Transport for structured data with + application on graphs", International Conference on Machine Learning + (ICML). 2019. + """ + # simple get_backend as the full one will be handled in gromov_wasserstein + nx = get_backend(M, C1, C2) + + # init marginals if set as None + if p is None: + p = unif(C1.shape[0], type_as=C1) + if q is None: + q = unif(C2.shape[0], type_as=C1) + + T, log_pfgw = partial_fused_gromov_wasserstein( + M, + C1, + C2, + p, + q, + m, + loss_fun, + alpha, + nb_dummies, + G0, + thres, + numItermax, + tol, + symmetric, + warn, + True, + verbose, + **kwargs, + ) + + log_pfgw["T"] = T + pfgw = log_pfgw["partial_fgw_dist"] + + # compute separate terms for gradients and log + lin_term = nx.sum(T * M) + log_pfgw["quad_loss"] = pfgw - (1 - alpha) * lin_term + log_pfgw["lin_loss"] = lin_term * (1 - alpha) + pgw_term = log_pfgw["quad_loss"] / alpha + + if loss_fun == "square_loss": + gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) + gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) + elif loss_fun == "kl_loss": + gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot( + T, nx.dot(nx.log(C2 + 1e-15), T.T) + ) + gC2 = -nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) + + if isinstance(alpha, int) or isinstance(alpha, float): + pfgw = nx.set_gradients( + pfgw, (M, C1, C2), ((1 - alpha) * T, alpha * gC1, alpha * gC2) + ) + else: + pfgw = nx.set_gradients( + pfgw, + (M, C1, C2, alpha), + ((1 - alpha) * T, alpha * gC1, alpha * gC2, pgw_term - lin_term), + ) + if log: + return 
pfgw, log_pfgw + else: + return pfgw + + def solve_partial_gromov_linesearch( G, deltaG, From a3b33bc629e044b1ffee90ff3ebc5ca9e702b730 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Thu, 14 Nov 2024 01:29:05 +0100 Subject: [PATCH 04/16] complete tests + solve_gromov --- ot/gromov/_partial.py | 4 +- ot/solvers.py | 33 +++++- test/gromov/test_partial.py | 216 ++++++++++++++++++++++++++++++++++++ test/test_solvers.py | 2 - 4 files changed, 251 insertions(+), 4 deletions(-) diff --git a/ot/gromov/_partial.py b/ot/gromov/_partial.py index da9682e25..0c8ec4794 100644 --- a/ot/gromov/_partial.py +++ b/ot/gromov/_partial.py @@ -637,13 +637,15 @@ def partial_fused_gromov_wasserstein( arr.append(G0) nx = get_backend(*arr) - p0, q0, M0, C10, C20 = p, q, M, C1, C2 + p0, q0, M0, C10, C20, alpha0 = p, q, M, C1, C2, alpha p = nx.to_numpy(p0) q = nx.to_numpy(q0) M = nx.to_numpy(M0) C1 = nx.to_numpy(C10) C2 = nx.to_numpy(C20) + alpha = nx.to_numpy(alpha0) + if symmetric is None: symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose( C2, C2.T, atol=1e-10 diff --git a/ot/solvers.py b/ot/solvers.py index ec56d1330..508f248d5 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -23,6 +23,7 @@ entropic_semirelaxed_fused_gromov_wasserstein2, entropic_semirelaxed_gromov_wasserstein2, partial_gromov_wasserstein2, + partial_fused_gromov_wasserstein2, entropic_partial_gromov_wasserstein2, ) from .gaussian import empirical_bures_wasserstein_distance @@ -779,6 +780,7 @@ def solve_gromov( .. code-block:: python res = ot.solve_gromov(Ca, Cb, unbalanced_type='partial', unbalanced=0.8) # partial GW with m=0.8 + res = ot.solve_gromov(Ca, Cb, M, unbalanced_type='partial', unbalanced=0.8, alpha=0.5) # partial FGW with m=0.8 .. 
_references-solve-gromov: @@ -1002,7 +1004,36 @@ def solve_gromov( # potentials = (log['u'], log['v']) TODO else: # partial FGW - raise (NotImplementedError("Partial FGW not implemented yet")) + if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + raise (ValueError("Partial FGW mass given in reg is too large")) + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-7 + + value, log = partial_fused_gromov_wasserstein2( + M, + Ca, + Cb, + a, + b, + m=unbalanced, + loss_fun=loss_fun, + alpha=alpha, + log=True, + numItermax=max_iter, + G0=plan_init, + tol=tol, + symmetric=symmetric, + verbose=verbose, + ) + + value_linear = log["lin_loss"] + value_quad = log["quad_loss"] + plan = log["T"] + # potentials = (log['u'], log['v']) TODO elif unbalanced_type.lower() in ["kl", "l2"]: # unbalanced exact OT raise (NotImplementedError('Unbalanced_type="{}"'.format(unbalanced_type))) diff --git a/test/gromov/test_partial.py b/test/gromov/test_partial.py index 1ae4e960f..1834de298 100644 --- a/test/gromov/test_partial.py +++ b/test/gromov/test_partial.py @@ -7,6 +7,7 @@ # License: MIT License import numpy as np +import torch import scipy as sp import ot import pytest @@ -42,6 +43,12 @@ def test_raise_errors(): with pytest.raises(ValueError): ot.gromov.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=-1, log=True) + with pytest.raises(ValueError): + ot.gromov.partial_fused_gromov_wasserstein(M, M, M, p, q, m=2, log=True) + + with pytest.raises(ValueError): + ot.gromov.partial_fused_gromov_wasserstein(M, M, M, p, q, m=-1, log=True) + def test_partial_gromov_wasserstein(nx): rng = np.random.RandomState(42) @@ -235,6 +242,215 @@ def test_partial_partial_gromov_linesearch(nx): np.testing.assert_allclose(alpha, 1.0, rtol=1e-4) +def test_partial_fused_gromov_wasserstein(nx): + rng = np.random.RandomState(42) + n_samples = 20 # nb samples + n_noise = 10 # nb of samples (noise) + + p = ot.unif(n_samples + n_noise) + psub = 
ot.unif(n_samples - 5 + n_noise) + q = ot.unif(n_samples + n_noise) + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + mu_t = np.array([0, 0, 0]) + cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + + # clean samples + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=rng) + P = sp.linalg.sqrtm(cov_t) + xt = rng.randn(n_samples, 3).dot(P) + mu_t + # add noise + xs = np.concatenate((xs, ((rng.rand(n_noise, 2) + 1) * 4)), axis=0) + xt = np.concatenate((xt, ((rng.rand(n_noise, 3) + 1) * 10)), axis=0) + xt2 = xs[::-1].copy() + + C1 = ot.dist(xs, xs) + F1 = xs + + C1sub = ot.dist(xs[5:], xs[5:]) + F1sub = xs[5:] + + C2 = ot.dist(xt, xt) + F2 = xs + + C3 = ot.dist(xt2, xt2) + F3 = xt2 + + M11sub = ot.dist(F1, F1sub) + M12 = ot.dist(F1, F2) + M13 = ot.dist(F1, F3) + + m = 2.0 / 3.0 + + M11subb, M12b, M13b, C1b, C1subb, C2b, C3b, pb, psubb, qb = nx.from_numpy( + M11sub, M12, M13, C1, C1sub, C2, C3, p, psub, q + ) + + G0 = ( + np.outer(p, q) * m / (np.sum(p) * np.sum(q)) + ) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. 
+ G0b = nx.from_numpy(G0) + + # check consistency across backends and stability w.r.t loss/marginals/sym + list_sym = [True, None] + for i, loss_fun in enumerate(["square_loss", "kl_loss"]): + res, log = ot.gromov.partial_fused_gromov_wasserstein( + M13, + C1, + C3, + p=p, + q=None, + m=m, + loss_fun=loss_fun, + alpha=0.3, + n_dummies=1, + G0=G0, + log=True, + symmetric=list_sym[i], + warn=True, + verbose=True, + ) + + resb, logb = ot.gromov.partial_fused_gromov_wasserstein( + M13b, + C1b, + C3b, + p=None, + q=qb, + m=m, + loss_fun=loss_fun, + alpha=0.3, + n_dummies=1, + G0=G0b, + log=True, + symmetric=False, + warn=True, + verbose=True, + ) + + resb_ = nx.to_numpy(resb) + assert np.all(res.sum(1) <= p) # cf convergence wasserstein + assert np.all(res.sum(0) <= q) # cf convergence wasserstein + + try: + # precision error while doubling numbers of computations with symmetric=False + # some instability can occur with kl. to investigate further. + # changing log offset in _transform_matrix was a way to solve it + # but it also negatively affects some other solvers in the API + np.testing.assert_allclose(res, resb_, rtol=1e-4) + except AssertionError: + pass + + # tests with different number of samples across spaces + m = 2.0 / 3.0 + res, log = ot.gromov.partial_fused_gromov_wasserstein( + M11sub, C1, C1sub, p=p, q=psub, m=m, log=True + ) + + resb, logb = ot.gromov.partial_fused_gromov_wasserstein( + M11subb, C1b, C1subb, p=pb, q=psubb, m=m, log=True + ) + + resb_ = nx.to_numpy(resb) + np.testing.assert_allclose(res, resb_, rtol=1e-4) + assert np.all(res.sum(1) <= p) # cf convergence wasserstein + assert np.all(res.sum(0) <= psub) # cf convergence wasserstein + np.testing.assert_allclose(np.sum(res), m, atol=1e-15) + + # Edge cases - tests with m=1 set by default (coincide with gw) + m = 1 + res0 = ot.gromov.partial_fused_gromov_wasserstein(M12, C1, C2, p, q, m=m, log=False) + res0b, log0b = ot.gromov.partial_fused_gromov_wasserstein( + M12b, C1b, C2b, pb, qb, m=None, 
log=True + ) + G = ot.gromov.fused_gromov_wasserstein(M12, C1, C2, p, q, "square_loss") + + np.testing.assert_allclose(G, res0, rtol=1e-4) + np.testing.assert_allclose(res0b, res0, rtol=1e-4) + + # tests for pGW2 + for loss_fun in ["square_loss", "kl_loss"]: + w0, log0 = ot.gromov.partial_fused_gromov_wasserstein2( + M12, C1, C2, p=None, q=q, m=m, loss_fun=loss_fun, log=True + ) + w0_val = ot.gromov.partial_fused_gromov_wasserstein2( + M12b, C1b, C2b, p=pb, q=None, m=m, loss_fun=loss_fun, log=False + ) + np.testing.assert_allclose(w0, w0_val, rtol=1e-4) + + # tests integers + C1_int = C1.astype(int) + C1b_int = nx.from_numpy(C1_int) + C2_int = C2.astype(int) + C2b_int = nx.from_numpy(C2_int) + + res0b, log0b = ot.gromov.partial_fused_gromov_wasserstein( + M12b, C1b_int, C2b_int, pb, qb, m=m, log=True + ) + + assert nx.to_numpy(res0b).dtype == C1_int.dtype + + +def test_partial_fgw2_gradients(): + n_samples = 20 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + M = ot.dist(xs, xt) + + C1 /= C1.max() + C2 /= C2.max() + + if torch: + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + for device in devices: + p1 = torch.tensor(p, requires_grad=False, device=device) + q1 = torch.tensor(q, requires_grad=False, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) + M1 = torch.tensor(M, requires_grad=True, device=device) + + val = ot.gromov.partial_fused_gromov_wasserstein2(M1, C11, C12, p1, q1) + + val.backward() + + assert val.device == p1.device + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + assert M1.shape == 
M1.grad.shape + + # full gradients with alpha + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) + M1 = torch.tensor(M, requires_grad=True, device=device) + alpha = torch.tensor(0.5, requires_grad=True, device=device) + + val = ot.gromov.partial_fused_gromov_wasserstein2( + M1, C11, C12, p1, q1, alpha=alpha + ) + + val.backward() + + assert val.device == p1.device + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + assert alpha.shape == alpha.grad.shape + + @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tf backend") def test_entropic_partial_gromov_wasserstein(nx): diff --git a/test/test_solvers.py b/test/test_solvers.py index 82a402df1..f6077e005 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -431,8 +431,6 @@ def test_solve_gromov_not_implemented(nx): # detect partial not implemented and error detect in value with pytest.raises(ValueError): ot.solve_gromov(Ca, Cb, unbalanced_type="partial", unbalanced=1.5) - with pytest.raises(NotImplementedError): - ot.solve_gromov(Ca, Cb, M, unbalanced_type="partial", unbalanced=0.5) with pytest.raises(ValueError): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type="partial", unbalanced=1.5) From 1de4196f6fc9a8c3daf970f4d05765c23d097ddc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Thu, 14 Nov 2024 01:34:22 +0100 Subject: [PATCH 05/16] complete tests + solve_gromov --- test/gromov/test_partial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gromov/test_partial.py b/test/gromov/test_partial.py index 1834de298..8b6205683 100644 --- a/test/gromov/test_partial.py +++ b/test/gromov/test_partial.py @@ -7,9 +7,9 @@ # License: MIT License import numpy as np -import torch import scipy as sp import ot +from ot.backend import torch import pytest From d8828f7c645dd9ddbf3fba1272772e88ec9cb73e Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Thu, 14 Nov 2024 01:59:05 +0100 Subject: [PATCH 06/16] release --- CONTRIBUTORS.md | 3 ++- RELEASES.md | 5 +++++ ot/gromov/_partial.py | 1 - 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 093137e2b..2ed6f59a0 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -41,7 +41,8 @@ The contributors to this library are: * [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein) * [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance) * [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends) -* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, FGW, semi-relaxed FGW, quantized FGW) +* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, FGW, + semi-relaxed FGW, quantized FGW, partial FGW) * [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters, GMMOT) * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) diff --git a/RELEASES.md b/RELEASES.md index 63d77bf19..11f4679b0 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -2,6 +2,11 @@ ## 0.9.6dev +#### New features +- Implement CG solvers for partial FGW (PR #687) + +#### Closed issues + ## 0.9.5 diff --git a/ot/gromov/_partial.py b/ot/gromov/_partial.py index 0c8ec4794..c6837f1d3 100644 --- a/ot/gromov/_partial.py +++ b/ot/gromov/_partial.py @@ -1054,7 +1054,6 @@ def solve_partial_gromov_linesearch( a = reg * cost_deltaG # formula to check for partial FGW b = nx.sum(M * deltaG) + reg * nx.sum(df_G * deltaG) - alpha = solve_1d_linesearch_quad(a, b) if alpha_min is not None or alpha_max is not None: alpha = np.clip(alpha, alpha_min, alpha_max) From 94a5e37a90fc95c40499bcd39944442c9df906ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Wed, 27 Nov 2024 14:58:43 +0100 Subject: [PATCH 07/16] 
partial entropic fgw solvers --- ot/gromov/__init__.py | 4 + ot/gromov/_partial.py | 377 ++++++++++++++++++++++++++++++++++++++++++ ot/solvers.py | 44 ++++- 3 files changed, 419 insertions(+), 6 deletions(-) diff --git a/ot/gromov/__init__.py b/ot/gromov/__init__.py index b7520e099..4c470b8de 100644 --- a/ot/gromov/__init__.py +++ b/ot/gromov/__init__.py @@ -107,6 +107,8 @@ solve_partial_gromov_linesearch, entropic_partial_gromov_wasserstein, entropic_partial_gromov_wasserstein2, + entropic_partial_fused_gromov_wasserstein, + entropic_partial_fused_gromov_wasserstein2, ) @@ -180,4 +182,6 @@ "solve_partial_gromov_linesearch", "entropic_partial_gromov_wasserstein", "entropic_partial_gromov_wasserstein2", + "entropic_partial_fused_gromov_wasserstein", + "entropic_partial_fused_gromov_wasserstein2", ] diff --git a/ot/gromov/_partial.py b/ot/gromov/_partial.py index c6837f1d3..fdfbba951 100644 --- a/ot/gromov/_partial.py +++ b/ot/gromov/_partial.py @@ -1433,3 +1433,380 @@ def entropic_partial_gromov_wasserstein2( return log_gw["partial_gw_dist"], log_gw else: return log_gw["partial_gw_dist"] + + +def entropic_partial_fused_gromov_wasserstein( + M, + C1, + C2, + p=None, + q=None, + reg=1.0, + m=None, + loss_fun="square_loss", + alpha=0.5, + G0=None, + numItermax=1000, + tol=1e-7, + symmetric=None, + log=False, + verbose=False, +): + r""" + Returns the entropic partial Fused Gromov-Wasserstein transport between + :math:`(\mathbf{C_1}, \mathbf{F_1}, \mathbf{p})` and + :math:`(\mathbf{C_2}, \mathbf{F_2}, \mathbf{q})`, with pairwise + distance matrix :math:`\mathbf{M}` between node feature matrices. + + The function solves the following optimization problem: + + .. math:: + \gamma = \mathop{\arg \min}_{\gamma} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot + \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) + + .. math:: + s.t. 
\ \gamma &\geq 0 + + \gamma \mathbf{1} &\leq \mathbf{p} + + \gamma^T \mathbf{1} &\leq \mathbf{q} + + \mathbf{1}^T \gamma^T \mathbf{1} = m + &\leq \min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\} + + where : + + - :math:`\mathbf{M}`: metric cost matrix between features across domains + - :math:`\mathbf{C_1}` is the metric cost matrix in the source space + - :math:`\mathbf{C_2}` is the metric cost matrix in the target space + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights + - `L`: quadratic loss function + - :math:`\Omega` is the entropic regularization term, + :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - `m` is the amount of mass to be transported + + The formulation of the FGW problem has been proposed in + :ref:`[24] <references-entropic-partial-fused-gromov-wasserstein>` and the + partial GW in :ref:`[29] <references-entropic-partial-fused-gromov-wasserstein>` + + Parameters + ---------- + M : array-like, shape (ns, nt) + Metric cost matrix between features across domains + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,), optional + Distribution in the source space. + If left to its default value None, uniform distribution is taken. + q : array-like, shape (nt,), optional + Distribution in the target space. + If left to its default value None, uniform distribution is taken. + reg : float, optional. Default is 1. + entropic regularization parameter + m : float, optional + Amount of mass to be transported (default: + :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) + loss_fun : str, optional + Loss function used for the solver either 'square_loss' or 'kl_loss'. + alpha : float, optional + Trade-off parameter (0 < alpha < 1) + G0 : array-like, shape (ns, nt), optional + Initialization of the transportation matrix + numItermax : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0) + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not.
+ If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + log : bool, optional + return log if True + verbose : bool, optional + Print information along iterations + + Returns + ------- + :math: `gamma` : (dim_a, dim_b) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if `log` is `True` + + + .. _references-entropic-partial-fused-gromov-wasserstein: + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain + and Courty Nicolas "Optimal Transport for structured data with + application on graphs", International Conference on Machine Learning + (ICML). 2019. + + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. + + See Also + -------- + ot.gromov.partial_fused_gromov_wasserstein: exact Partial Fused Gromov-Wasserstein + """ + + arr = [M, C1, C2, G0] + if p is not None: + p = list_to_array(p) + arr.append(p) + if q is not None: + q = list_to_array(q) + arr.append(q) + + nx = get_backend(*arr) + + if p is None: + p = nx.ones(C1.shape[0], type_as=C1) / C1.shape[0] + if q is None: + q = nx.ones(C2.shape[0], type_as=C2) / C2.shape[0] + + if m is None: + m = min(nx.sum(p), nx.sum(q)) + elif m < 0: + raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") + elif m > min(nx.sum(p), nx.sum(q)): + raise ValueError( + "Problem infeasible. Parameter m should lower or" + " equal than min(|a|_1, |b|_1)." + ) + + if G0 is None: + G0 = ( + nx.outer(p, q) * m / (nx.sum(p) * nx.sum(q)) + ) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. 
+ + else: + # Check marginals of G0 + assert nx.any(nx.sum(G0, 1) <= p) + assert nx.any(nx.sum(G0, 0) <= q) + + if symmetric is None: + symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose( + C2, C2.T, atol=1e-10 + ) + + # Setup gradient computation + fC1, fC2, hC1, hC2 = _transform_matrix(C1, C2, loss_fun, nx) + fC2t = fC2.T + if not symmetric: + fC1t, hC1t, hC2t = fC1.T, hC1.T, hC2.T + + ones_p = nx.ones(p.shape[0], type_as=p) + ones_q = nx.ones(q.shape[0], type_as=q) + + def f(G): + pG = nx.sum(G, 1) + qG = nx.sum(G, 0) + constC1 = nx.outer(nx.dot(fC1, pG), ones_q) + constC2 = nx.outer(ones_p, nx.dot(qG, fC2t)) + return alpha * gwloss(constC1 + constC2, hC1, hC2, G, nx) + ( + 1 - alpha + ) * nx.sum(G * M) + + if symmetric: + + def df(G): + pG = nx.sum(G, 1) + qG = nx.sum(G, 0) + constC1 = nx.outer(nx.dot(fC1, pG), ones_q) + constC2 = nx.outer(ones_p, nx.dot(qG, fC2t)) + return alpha * gwggrad(constC1 + constC2, hC1, hC2, G, nx) + ( + 1 - alpha + ) * nx.sum(G * M) + else: + + def df(G): + pG = nx.sum(G, 1) + qG = nx.sum(G, 0) + constC1 = nx.outer(nx.dot(fC1, pG), ones_q) + constC2 = nx.outer(ones_p, nx.dot(qG, fC2t)) + constC1t = nx.outer(nx.dot(fC1t, pG), ones_q) + constC2t = nx.outer(ones_p, nx.dot(qG, fC2)) + + return 0.5 * alpha * ( + gwggrad(constC1 + constC2, hC1, hC2, G, nx) + + gwggrad(constC1t + constC2t, hC1t, hC2t, G, nx) + ) + (1 - alpha) * nx.sum(G * M) + + cpt = 0 + err = 1 + + loge = {"err": []} + + while err > tol and cpt < numItermax: + Gprev = G0 + M_entr = df(G0) + G0 = entropic_partial_wasserstein(p, q, M_entr, reg, m) + if cpt % 10 == 0: # to speed up the computations + err = np.linalg.norm(G0 - Gprev) + if log: + loge["err"].append(err) + if verbose: + if cpt % 200 == 0: + print( + "{:5s}|{:12s}|{:12s}".format("It.", "Err", "Loss") + + "\n" + + "-" * 31 + ) + print("{:5d}|{:8e}|{:8e}".format(cpt, err, f(G0))) + + cpt += 1 + + if log: + loge["partial_fgw_dist"] = f(G0) + return G0, loge + else: + return G0 + + +def 
entropic_partial_fused_gromov_wasserstein2( + M, + C1, + C2, + p=None, + q=None, + reg=1.0, + m=None, + loss_fun="square_loss", + alpha=0.5, + G0=None, + numItermax=1000, + tol=1e-7, + symmetric=None, + log=False, + verbose=False, +): + r""" + Returns the entropic partial Fused Gromov-Wasserstein discrepancy between + :math:`(\mathbf{C_1}, \mathbf{F_1}, \mathbf{p})` and + :math:`(\mathbf{C_2}, \mathbf{F_2}, \mathbf{q})`, with pairwise + distance matrix :math:`\mathbf{M}` between node feature matrices. + + The function solves the following optimization problem: + + .. math:: + PGW = \min_{\gamma} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot + \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) + + .. math:: + s.t. \ \gamma &\geq 0 + + \gamma \mathbf{1} &\leq \mathbf{a} + + \gamma^T \mathbf{1} &\leq \mathbf{b} + + \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} + + where : + + - :math:`\mathbf{M}`: metric cost matrix between features across domains + - :math:`\mathbf{C_1}` is the metric cost matrix in the source space + - :math:`\mathbf{C_2}` is the metric cost matrix in the target space + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights + - `L`: Loss function to account for the misfit between the similarity matrices. 
+ - :math:`\Omega` is the entropic regularization term, + :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - `m` is the amount of mass to be transported + + The formulation of the FGW problem has been proposed in + :ref:`[24] ` and the + partial GW in :ref:`[29] ` + + Parameters + ---------- + M : array-like, shape (ns, nt) + Metric cost matrix between features across domains + C1 : ndarray, shape (ns, ns) + Metric cost matrix in the source space + C2 : ndarray, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. + q : array-like, shape (nt,), optional + Distribution in the target space. + If let to its default value None, uniform distribution is taken. + reg: float + entropic regularization parameter + m : float, optional + Amount of mass to be transported (default: + :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) + loss_fun : str, optional + Loss function used for the solver either 'square_loss' or 'kl_loss'. + alpha : float, optional + Trade-off parameter (0 < alpha < 1) + G0 : ndarray, shape (ns, nt), optional + Initialization of the transportation matrix + numItermax : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0) + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + log : bool, optional + return log if True + verbose : bool, optional + Print information along iterations + + + Returns + ------- + partial_fgw_dist: float + Partial Entropic Fused Gromov-Wasserstein discrepancy + log : dict + log dictionary returned only if `log` is `True` + + .. _references-entropic-partial-fused-gromov-wasserstein2: + References + ---------- + .. 
[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain + and Courty Nicolas "Optimal Transport for structured data with + application on graphs", International Conference on Machine Learning + (ICML). 2019. + + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. + """ + nx = get_backend(M, C1, C2) + + T, log_pfgw = entropic_partial_fused_gromov_wasserstein( + M, + C1, + C2, + p, + q, + reg, + m, + loss_fun, + alpha, + G0, + numItermax, + tol, + symmetric, + True, + verbose, + ) + + log_pfgw["T"] = T + + # setup for ot.solve_gromov + lin_term = nx.sum(T * M) + log_pfgw["quad_loss"] = log_pfgw["partial_fgw_dist"] - (1 - alpha) * lin_term + log_pfgw["lin_loss"] = lin_term * (1 - alpha) + + if log: + return log_pfgw["partial_fgw_dist"], log_pfgw + else: + return log_pfgw["partial_fgw_dist"] diff --git a/ot/solvers.py b/ot/solvers.py index 508f248d5..a633dd7b3 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -25,6 +25,7 @@ partial_gromov_wasserstein2, partial_fused_gromov_wasserstein2, entropic_partial_gromov_wasserstein2, + entropic_partial_fused_gromov_wasserstein2, ) from .gaussian import empirical_bures_wasserstein_distance from .factored import factored_optimal_transport @@ -974,7 +975,7 @@ def solve_gromov( # potentials = (log['u'], log['v']) TODO elif unbalanced_type.lower() in ["partial"]: # Partial OT - if M is None: # Partial Gromov-Wasserstein problem + if M is None or alpha == 1.0: # Partial Gromov-Wasserstein problem if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): raise (ValueError("Partial GW mass given in reg is too large")) @@ -1204,7 +1205,7 @@ def solve_gromov( value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) elif unbalanced_type.lower() in ["partial"]: # Partial OT - if M is None: # Partial Gromov-Wasserstein problem + if M is None or alpha == 1.0: # Partial Gromov-Wasserstein problem if unbalanced > nx.sum(a) or unbalanced > 
nx.sum(b): raise (ValueError("Partial GW mass given in reg is too large")) @@ -1214,7 +1215,7 @@ def solve_gromov( if tol is None: tol = 1e-7 - value_quad, log = entropic_partial_gromov_wasserstein2( + value_noreg, log = entropic_partial_gromov_wasserstein2( Ca, Cb, a, @@ -1230,12 +1231,43 @@ def solve_gromov( verbose=verbose, ) - value_quad = value + value_quad = value_noreg plan = log["T"] # potentials = (log['u'], log['v']) TODO - + value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) else: # partial FGW - raise (NotImplementedError("Partial entropic FGW not implemented yet")) + if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + raise (ValueError("Partial FGW mass given in reg is too large")) + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-7 + + value_noreg, log = entropic_partial_fused_gromov_wasserstein2( + M, + Ca, + Cb, + a, + b, + reg=reg, + loss_fun=loss_fun, + alpha=alpha, + m=unbalanced, + log=True, + numItermax=max_iter, + G0=plan_init, + tol=tol, + symmetric=symmetric, + verbose=verbose, + ) + + value_linear = log["lin_loss"] + value_quad = log["quad_loss"] + plan = log["T"] + # potentials = (log['u'], log['v']) TODO + value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) else: # unbalanced AND regularized OT raise ( From 4c9717723872807b73c0c2765abfdac259e902d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Wed, 27 Nov 2024 15:16:38 +0100 Subject: [PATCH 08/16] add tests --- test/gromov/test_partial.py | 163 ++++++++++++++++++++++++++++++++++++ 1 file changed, 163 insertions(+) diff --git a/test/gromov/test_partial.py b/test/gromov/test_partial.py index 8b6205683..3b133242c 100644 --- a/test/gromov/test_partial.py +++ b/test/gromov/test_partial.py @@ -49,6 +49,16 @@ def test_raise_errors(): with pytest.raises(ValueError): ot.gromov.partial_fused_gromov_wasserstein(M, M, M, p, q, m=-1, log=True) + with pytest.raises(ValueError): + 
ot.gromov.entropic_partial_fused_gromov_wasserstein( + M, M, M, p, q, m=2, log=True + ) + + with pytest.raises(ValueError): + ot.gromov.entropic_partial_fused_gromov_wasserstein( + M, M, M, p, q, m=-1, log=True + ) + def test_partial_gromov_wasserstein(nx): rng = np.random.RandomState(42) @@ -585,3 +595,156 @@ def test_entropic_partial_gromov_wasserstein(nx): C1b, C2b, p=pb, q=None, reg=1e4, m=m, loss_fun=loss_fun, log=False ) np.testing.assert_allclose(w0, w0_val, rtol=1e-8) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tf backend") +def test_entropic_partial_fused_gromov_wasserstein(nx): + rng = np.random.RandomState(42) + n_samples = 20 # nb samples + n_noise = 10 # nb of samples (noise) + + p = ot.unif(n_samples + n_noise) + psub = ot.unif(n_samples - 5 + n_noise) + q = ot.unif(n_samples + n_noise) + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + mu_t = np.array([0, 0, 0]) + cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + + # clean samples + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=rng) + P = sp.linalg.sqrtm(cov_t) + xt = rng.randn(n_samples, 3).dot(P) + mu_t + # add noise + xs = np.concatenate((xs, ((rng.rand(n_noise, 2) + 1) * 4)), axis=0) + xt = np.concatenate((xt, ((rng.rand(n_noise, 3) + 1) * 10)), axis=0) + xt2 = xs[::-1].copy() + + C1 = ot.dist(xs, xs) + F1 = xs + + C1sub = ot.dist(xs[5:], xs[5:]) + F1sub = xs[5:] + + C2 = ot.dist(xt, xt) + F2 = xs + + C3 = ot.dist(xt2, xt2) + F3 = xt2 + + M11sub = ot.dist(F1, F1sub) + M12 = ot.dist(F1, F2) + M13 = ot.dist(F1, F3) + + m = 2.0 / 3.0 + + M11subb, M12b, M13b, C1b, C1subb, C2b, C3b, pb, psubb, qb = nx.from_numpy( + M11sub, M12, M13, C1, C1sub, C2, C3, p, psub, q + ) + + G0 = ( + np.outer(p, q) * m / (np.sum(p) * np.sum(q)) + ) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. 
+ G0b = nx.from_numpy(G0) + + # check consistency across backends and stability w.r.t loss/marginals/sym + list_sym = [True, None] + for i, loss_fun in enumerate(["square_loss", "kl_loss"]): + res, log = ot.gromov.entropic_partial_fused_gromov_wasserstein( + M13, + C1, + C3, + p=p, + q=None, + reg=1e4, + m=m, + loss_fun=loss_fun, + G0=None, + log=True, + symmetric=list_sym[i], + verbose=True, + ) + + resb, logb = ot.gromov.entropic_partial_fused_gromov_wasserstein( + M13b, + C1b, + C3b, + p=None, + q=qb, + reg=1e4, + m=m, + loss_fun=loss_fun, + G0=G0b, + log=True, + symmetric=False, + verbose=True, + ) + + resb_ = nx.to_numpy(resb) + try: # some instability can occur with kl. to investigate further. + np.testing.assert_allclose(res, resb_, rtol=1e-4) + except AssertionError: + pass + + assert np.all(res.sum(1) <= p) # cf convergence wasserstein + assert np.all(res.sum(0) <= q) # cf convergence wasserstein + + # tests with m is None + res = ot.gromov.entropic_partial_fused_gromov_wasserstein( + M13, + C1, + C3, + p=p, + q=None, + reg=1e4, + G0=None, + log=False, + symmetric=list_sym[i], + verbose=True, + ) + + resb = ot.gromov.entropic_partial_fused_gromov_wasserstein( + M13b, + C1b, + C3b, + p=None, + q=qb, + reg=1e4, + G0=None, + log=False, + symmetric=False, + verbose=True, + ) + + resb_ = nx.to_numpy(resb) + np.testing.assert_allclose(res, resb_, rtol=1e-4) + np.testing.assert_allclose(np.sum(res), 1.0, rtol=1e-4) + + # tests with different number of samples across spaces + m = 0.5 + res, log = ot.gromov.entropic_partial_fused_gromov_wasserstein( + M11sub, C1, C1sub, p=p, q=psub, reg=1e4, m=m, log=True + ) + + resb, logb = ot.gromov.entropic_partial_fused_gromov_wasserstein( + M11subb, C1b, C1subb, p=pb, q=psubb, reg=1e4, m=m, log=True + ) + + resb_ = nx.to_numpy(resb) + np.testing.assert_allclose(res, resb_, rtol=1e-4) + assert np.all(res.sum(1) <= p) # cf convergence wasserstein + assert np.all(res.sum(0) <= psub) # cf convergence wasserstein + 
np.testing.assert_allclose(np.sum(res), m, rtol=1e-4) + + # tests for pGW2 + for loss_fun in ["square_loss", "kl_loss"]: + w0, log0 = ot.gromov.entropic_partial_fused_gromov_wasserstein2( + M12, C1, C2, p=None, q=q, reg=1e4, m=m, loss_fun=loss_fun, log=True + ) + w0_val = ot.gromov.entropic_partial_fused_gromov_wasserstein2( + M12b, C1b, C2b, p=pb, q=None, reg=1e4, m=m, loss_fun=loss_fun, log=False + ) + np.testing.assert_allclose(w0, w0_val, rtol=1e-8) From a61aa32067156f06ff25f3b87c69be87257718dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Wed, 27 Nov 2024 15:56:15 +0100 Subject: [PATCH 09/16] complete solve_gromov --- ot/solvers.py | 3 ++- test/test_solvers.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/ot/solvers.py b/ot/solvers.py index 4b5cf2084..96794d9cd 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -863,7 +863,8 @@ def solve_gromov( if reg is None or reg == 0: # exact OT if unbalanced is None and unbalanced_type.lower() not in [ - "semirelaxed" + "semirelaxed", + "partial", ]: # Exact balanced OT if M is None or alpha == 1: # Gromov-Wasserstein problem # default values for solver diff --git a/test/test_solvers.py b/test/test_solvers.py index 85852aca6..a0c1d7c43 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -518,6 +518,10 @@ def test_solve_gromov_not_implemented(nx): ot.solve_gromov(Ca, Cb, unbalanced_type="partial", unbalanced=1.5) with pytest.raises(ValueError): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type="partial", unbalanced=1.5) + with pytest.raises(ValueError): + ot.solve_gromov(Ca, Cb, M, unbalanced_type="partial", unbalanced=1.5) + with pytest.raises(ValueError): + ot.solve_gromov(Ca, Cb, M, reg=1, unbalanced_type="partial", unbalanced=1.5) def test_solve_sample(nx): From 0c16596aa606f492ca92e079495162584f173e03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Wed, 27 Nov 2024 16:02:29 +0100 Subject: [PATCH 10/16] update --- README.md | 2 
+- RELEASES.md | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 89e16b508..ba674d0c8 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ POT provides the following generic OT solvers (links to examples): * [Sampled solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33] * Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20]. * [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41] -* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] +* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) and Partial Fused Gromov-Wasserstein (exact [29] and entropic [3] formulations). * [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36]. 
* [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45] diff --git a/RELEASES.md b/RELEASES.md index c2accd8bf..1120a5f9d 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,6 +5,7 @@ #### New features - Implement CG solvers for partial FGW (PR #687) - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693) +- Implement projected gradient descent solvers for entropic partial FGW (PR #702) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) From e5c4711e1c98e160276774964fcd5f46fe0822fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Mon, 6 Jan 2025 01:51:08 +0100 Subject: [PATCH 11/16] fix solvers --- examples/gromov/plot_barycenter_fgw.py | 10 +- examples/gromov/plot_partial_fgw.py | 380 ++++++++++++++++++ .../plot_partial_wass_and_gromov.py | 5 +- ot/gromov/_partial.py | 8 +- ot/solvers.py | 42 +- 5 files changed, 428 insertions(+), 17 deletions(-) create mode 100644 examples/gromov/plot_partial_fgw.py diff --git a/examples/gromov/plot_barycenter_fgw.py b/examples/gromov/plot_barycenter_fgw.py index 865c1e71a..b51b4e1ff 100644 --- a/examples/gromov/plot_barycenter_fgw.py +++ b/examples/gromov/plot_barycenter_fgw.py @@ -91,7 +91,7 @@ def build_noisy_circular_graph( g = nx.Graph() g.add_nodes_from(list(range(N))) for i in range(N): - noise = float(np.random.normal(mu, sigma, 1)) + noise = np.random.normal(mu, sigma, 1)[0] if with_noise: g.add_node(i, attr_name=math.sin((2 * i * math.pi / N)) + noise) else: @@ -107,7 +107,7 @@ def build_noisy_circular_graph( if i == N - 1: g.add_edge(i, 1) g.add_edge(N, 0) - noise = float(np.random.normal(mu, sigma, 1)) + noise = np.random.normal(mu, sigma, 1)[0] if with_noise: g.add_node(N, attr_name=math.sin((2 * N * math.pi / N)) + noise) else: @@ -157,7 +157,7 @@ def graph_colors(nx_graph, vmin=0, vmax=7): plt.subplot(3, 3, i + 1) g = X0[i] pos = 
nx.kamada_kawai_layout(g) - nx.draw( + nx.draw_networkx( g, pos=pos, node_color=graph_colors(g, vmin=-1, vmax=1), @@ -173,7 +173,7 @@ def graph_colors(nx_graph, vmin=0, vmax=7): # %% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph # Features distances are the euclidean distances -Cs = [shortest_path(nx.adjacency_matrix(x).todense()) for x in X0] +Cs = [shortest_path(nx.adjacency_matrix(x).toarray()) for x in X0] ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0] Ys = [ np.array([v for (k, v) in nx.get_node_attributes(x, "attr_name").items()]).reshape( @@ -199,7 +199,7 @@ def graph_colors(nx_graph, vmin=0, vmax=7): # %% pos = nx.kamada_kawai_layout(bary) -nx.draw( +nx.draw_networkx( bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False ) plt.suptitle("Barycenter", fontsize=20) diff --git a/examples/gromov/plot_partial_fgw.py b/examples/gromov/plot_partial_fgw.py new file mode 100644 index 000000000..87489ee46 --- /dev/null +++ b/examples/gromov/plot_partial_fgw.py @@ -0,0 +1,380 @@ +# -*- coding: utf-8 -*- +""" +================================= +Plot partial FGW for subgraph matching +================================= + +This example illustrates the computation of partial (Fused) Gromov-Wasserstein +divergences for subgraph matching tasks [18, 29]. + +[18] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain +and Courty Nicolas +"Optimal Transport for structured data with application on graphs" +International Conference on Machine Learning (ICML). 2019. + +[29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal +Transport with Applications on Positive-Unlabeled Learning". NeurIPS. 
+""" + +# Author: Cédric Vincent-Cuaz +# +# License: MIT License + +# %% load libraries +import numpy as np +import pylab as pl +import networkx as nx +import math +from scipy.sparse.csgraph import shortest_path +import matplotlib.colors as mcol +from matplotlib import cm +from ot.gromov import ( + partial_gromov_wasserstein, + entropic_partial_gromov_wasserstein, + partial_fused_gromov_wasserstein, + entropic_partial_fused_gromov_wasserstein, +) +from ot import unif, dist +# %% Graph generation and visualization functions + + +def build_noisy_circular_graph(n_clean=15, n_noise=5, random_seed=0): + """Create a noisy circular graph""" + # create clean circle + np.random.seed(random_seed) + g = nx.Graph() + g.add_nodes_from(np.arange(n_clean + n_noise)) + for i in range(n_clean): + g.add_node(i, weight=math.sin(2 * i * math.pi / n_clean)) + if i == (n_clean - 1): + g.add_edge(i, 0) + else: + g.add_edge(i, i + 1) + # add nodes out of the circle as structure noise + if n_noise > 0: + noisy_nodes = np.random.choice(np.arange(n_clean), n_noise) + for i, j in enumerate(noisy_nodes): + g.add_node(i + n_clean, weight=math.sin(2 * j * math.pi / n_clean)) + g.add_edge(i + n_clean, j) + return g + + +def graph_colors(nx_graph, vmin=0, vmax=7): + cnorm = mcol.Normalize(vmin=vmin, vmax=vmax) + cpick = cm.ScalarMappable(norm=cnorm, cmap="viridis") + cpick.set_array([]) + val_map = {} + for k, v in nx.get_node_attributes(nx_graph, "weight").items(): + val_map[k] = cpick.to_rgba(v) + colors = [] + for node in nx_graph.nodes(): + colors.append(val_map[node]) + return colors + + +def draw_graph( + G, + C, + nodes_color_part, + Gweights=None, + pos=None, + edge_color="black", + node_size=None, + shiftx=0, +): + if pos is None: + pos = nx.kamada_kawai_layout(G) + + if shiftx != 0: + for k, v in pos.items(): + v[0] = v[0] + shiftx + + alpha_edge = 0.7 + width_edge = 1.8 + if Gweights is None: + nx.draw_networkx_edges( + G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color + 
) + else: + # We make more visible connections between activated nodes + n = len(Gweights) + edgelist_activated = [] + edgelist_deactivated = [] + for i in range(n): + for j in range(n): + if Gweights[i] * Gweights[j] * C[i, j] > 0: + edgelist_activated.append((i, j)) + elif C[i, j] > 0: + edgelist_deactivated.append((i, j)) + + nx.draw_networkx_edges( + G, + pos, + edgelist=edgelist_activated, + width=width_edge, + alpha=alpha_edge, + edge_color=edge_color, + ) + nx.draw_networkx_edges( + G, + pos, + edgelist=edgelist_deactivated, + width=width_edge, + alpha=0.1, + edge_color=edge_color, + ) + + if Gweights is None: + for node, node_color in enumerate(nodes_color_part): + nx.draw_networkx_nodes( + G, + pos, + nodelist=[node], + node_size=node_size, + alpha=1, + node_color=node_color, + ) + else: + scaled_Gweights = Gweights / (0.5 * Gweights.max()) + nodes_size = node_size * scaled_Gweights + for node, node_color in enumerate(nodes_color_part): + nx.draw_networkx_nodes( + G, + pos, + nodelist=[node], + node_size=nodes_size[node], + alpha=1, + node_color=node_color, + ) + return pos + + +def draw_transp_colored( + G1, + C1, + G2, + C2, + p1, + p2, + T, + pos1=None, + pos2=None, + shiftx=4, + switchx=False, + node_size=70, + color_features=False, +): + if color_features: + nodes_color_part1 = graph_colors(G1, vmin=-1, vmax=1) + nodes_color_part2 = graph_colors(G2, vmin=-1, vmax=1) + else: + nodes_color_part1 = C1.shape[0] * ["C0"] + nodes_color_part2 = C2.shape[0] * ["C0"] + + pos1 = draw_graph( + G1, + C1, + nodes_color_part1, + Gweights=p1, + pos=pos1, + node_size=node_size, + shiftx=0, + ) + pos2 = draw_graph( + G2, + C2, + nodes_color_part2, + Gweights=p2, + pos=pos2, + node_size=node_size, + shiftx=shiftx, + ) + T_max = T.max() + for k1, v1 in pos1.items(): + for k2, v2 in pos2.items(): + if T[k1, k2] > 0: + pl.plot( + [pos1[k1][0], pos2[k2][0]], + [pos1[k1][1], pos2[k2][1]], + "-", + lw=0.8, + alpha=0.5 * T[k1, k2] / T_max, + color=nodes_color_part1[k1], + ) + 
return pos1, pos2 + + +# %% +############################################################################## +# Generate and visualize data +# ------------- + +# We build a clean circular graph that will be matched to a noisy circular graph. + +clean_graph = build_noisy_circular_graph(n_clean=15, n_noise=0) + +noisy_graph = build_noisy_circular_graph(n_clean=15, n_noise=5) + +graphs = [clean_graph, noisy_graph] +list_pos = [] +pl.figure(figsize=(6, 3)) +for i in range(2): + pl.subplot(1, 2, i + 1) + g = graphs[i] + if i == 0: + pl.title("clean graph", fontsize=16) + else: + pl.title("noisy graph", fontsize=16) + pos = nx.kamada_kawai_layout(g) + list_pos.append(pos) + nx.draw_networkx( + g, + pos=pos, + node_color=graph_colors(g, vmin=-1, vmax=1), + with_labels=False, + node_size=100, + ) +pl.show() + +############################################################################## +# Partial (Entropic) Gromov-Wasserstein computation and visualization +# ---------------------- + +# Adjacency matrices are compared using both exact and entropic partial GW +# discarding for now node features +Cs = [nx.adjacency_matrix(G).toarray().astype(np.float64) for G in graphs] +ps = [unif(C.shape[0]) for C in Cs] + +# provide an informative initialization for visualization +m = 3.0 / 4.0 +partial_id = np.zeros((15, 20)) +partial_id[:15, :15] = np.eye(15) / 15.0 +G0 = (np.outer(ps[0], ps[1]) + partial_id) * m / 2 + +# compute exact partial GW +T, log = partial_gromov_wasserstein( + Cs[0], Cs[1], ps[0], ps[1], m=m, G0=G0, symmetric=True, log=True +) + +# compute entropic partial GW leading to dense transport plans +Tent, logent = entropic_partial_gromov_wasserstein( + Cs[0], Cs[1], ps[0], ps[1], reg=0.01, m=m, G0=G0, symmetric=True, log=True +) + +# Plot matchings +list_T = [T, Tent] +list_dist = [ + np.round(log["partial_gw_dist"], 3), + np.round(logent["partial_gw_dist"], 3), +] +list_dist_str = ["pGW", "pGW_e"] +pl.figure(2, figsize=(10, 3)) +pl.clf() +for i in range(2): + 
pl.subplot(1, 2, i + 1) + pl.axis("off") + pl.title( + r"$%s(\mathbf{C_1},\mathbf{p_1},\mathbf{C_2}) =%s$" + % (list_dist_str[i], list_dist[i]), + fontsize=14, + ) + + p2 = list_T[i].sum(0) + pos1, pos2 = draw_transp_colored( + clean_graph, + Cs[0], + noisy_graph, + Cs[1], + p1=None, + p2=p2, + T=list_T[i], + shiftx=3, + node_size=50, + ) + +pl.tight_layout() +pl.show() + +############################################################################## +# Partial (Entropic) Fused Gromov-Wasserstein computation and visualization +# ---------------------- + +# Add now node features compared using pairwise euclidean distance +# to illustrate partial FGW computation with trade-off parameter alpha=0.5 +Ys = [ + np.array([v for (k, v) in nx.get_node_attributes(G, "weight").items()]).reshape( + -1, 1 + ) + for G in graphs +] +M = dist(Ys[0], Ys[1]) +# provide an informative initialization for visualization +m = 3.0 / 4.0 +partial_id = np.zeros((15, 20)) +partial_id[:15, :15] = np.eye(15) / 15.0 +G0 = (np.outer(ps[0], ps[1]) + partial_id) * m / 2 + +# compute exact partial GW +T, log = partial_fused_gromov_wasserstein( + M, + Cs[0], + Cs[1], + ps[0], + ps[1], + alpha=0.5, + m=m, + G0=G0, + symmetric=True, + log=True, +) + +# compute entropic partial GW leading to dense transport plans +Tent, logent = entropic_partial_fused_gromov_wasserstein( + M, + Cs[0], + Cs[1], + ps[0], + ps[1], + reg=0.01, + alpha=0.5, + m=m, + G0=G0, + symmetric=True, + log=True, +) + +# Plot matchings +list_T = [T, Tent] +list_dist = [ + np.round(log["partial_fgw_dist"], 3), + np.round(logent["partial_fgw_dist"], 3), +] +list_dist_str = ["pFGW", "pFGW_e"] + +pl.figure(3, figsize=(10, 3)) +pl.clf() +for i in range(2): + pl.subplot(1, 2, i + 1) + pl.axis("off") + pl.title( + r"$%s(\mathbf{C_1},\mathbf{p_1},\mathbf{C_2}) =%s$" + % (list_dist_str[i], list_dist[i]), + fontsize=14, + ) + + p2 = list_T[i].sum(0) + pos1, pos2 = draw_transp_colored( + clean_graph, + Cs[0], + noisy_graph, + Cs[1], + p1=None, + 
p2=p2, + T=list_T[i], + shiftx=3, + node_size=50, + color_features=True, + ) + +pl.tight_layout() +pl.show() diff --git a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py index 5ccc197d6..23a5f96a2 100755 --- a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py +++ b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py @@ -5,7 +5,10 @@ ================================================== This example is designed to show how to use the Partial (Gromov-)Wasserstein -distance computation in POT. +distance computation in POT [29]. + +[29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal +Transport with Applications on Positive-Unlabeled Learning". NeurIPS. """ # Author: Laetitia Chapel diff --git a/ot/gromov/_partial.py b/ot/gromov/_partial.py index fdfbba951..5a069fdaf 100644 --- a/ot/gromov/_partial.py +++ b/ot/gromov/_partial.py @@ -1173,7 +1173,7 @@ def entropic_partial_gromov_wasserstein( Returns ------- - :math: `gamma` : (dim_a, dim_b) ndarray + :math: `gamma` : ndarray, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary returned only if `log` is `True` @@ -1461,7 +1461,7 @@ def entropic_partial_fused_gromov_wasserstein( The function solves the following optimization problem: .. 
math:: - \gamma = \mathop{\arg \min}_{\gamma} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \gamma = \mathop{\arg \min}_{\gamma} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) @@ -1530,7 +1530,7 @@ def entropic_partial_fused_gromov_wasserstein( Returns ------- - :math: `gamma` : (dim_a, dim_b) ndarray + :math: `gamma` : ndarray, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary returned only if `log` is `True` @@ -1693,7 +1693,7 @@ def entropic_partial_fused_gromov_wasserstein2( The function solves the following optimization problem: .. math:: - PGW = \min_{\gamma} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + PGW = \min_{\gamma} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) diff --git a/ot/solvers.py b/ot/solvers.py index 96794d9cd..decf6177e 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -1002,8 +1002,17 @@ def solve_gromov( elif unbalanced_type.lower() in ["partial"]: # Partial OT if M is None or alpha == 1.0: # Partial Gromov-Wasserstein problem - if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): - raise (ValueError("Partial GW mass given in reg is too large")) + if unbalanced is None: + raise ( + ValueError( + "Partial GW mass given in `unbalanced` must be float and not None" + ) + ) + + elif unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + raise ( + ValueError("Partial GW mass given in `unbalanced` is too large") + ) # default values for solver if max_iter is None: @@ -1074,7 +1083,8 @@ def solve_gromov( else: # regularized OT if unbalanced is None and unbalanced_type.lower() not in [ - "semirelaxed" + "semirelaxed", + "partial", ]: # Balanced regularized OT if 
reg_type.lower() in ["entropy"] and ( M is None or alpha == 1 @@ -1232,8 +1242,17 @@ def solve_gromov( elif unbalanced_type.lower() in ["partial"]: # Partial OT if M is None or alpha == 1.0: # Partial Gromov-Wasserstein problem - if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): - raise (ValueError("Partial GW mass given in reg is too large")) + if unbalanced is None: + raise ( + ValueError( + "Partial GW mass given in `unbalanced` must be float and not None" + ) + ) + + elif unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + raise ( + ValueError("Partial GW mass given in `unbalanced` is too large") + ) # default values for solver if max_iter is None: @@ -1262,8 +1281,17 @@ def solve_gromov( # potentials = (log['u'], log['v']) TODO value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) else: # partial FGW - if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): - raise (ValueError("Partial FGW mass given in reg is too large")) + if unbalanced is None: + raise ( + ValueError( + "Partial GW mass given in `unbalanced` must be float and not None" + ) + ) + + elif unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + raise ( + ValueError("Partial GW mass given in `unbalanced` is too large") + ) # default values for solver if max_iter is None: From 8e79b24932d5e3382adf5506d6f05140d1dd2ad5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Mon, 6 Jan 2025 02:08:05 +0100 Subject: [PATCH 12/16] fix solvers --- test/test_solvers.py | 123 ++++++++++++++++++++++++++----------------- 1 file changed, 75 insertions(+), 48 deletions(-) diff --git a/test/test_solvers.py b/test/test_solvers.py index a0c1d7c43..b1bd097a3 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -432,55 +432,82 @@ def test_solve_gromov_grid(nx, reg, reg_type, unbalanced, unbalanced_type, alpha M = np.random.rand(n_samples_s, n_samples_t) try: - sol0 = ot.solve_gromov( - Ca, - Cb, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - 
unbalanced_type=unbalanced_type, - loss=loss, - ) # GW - sol0_fgw = ot.solve_gromov( - Ca, - Cb, - M, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - alpha=alpha, - loss=loss, - ) # FGW - - # solve in backend - ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb) + if unbalanced_type == "partial" and unbalanced is None: + ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb) + + with pytest.raises(ValueError): + solx = ot.solve_gromov( + Cax, + Cbx, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + loss=loss, + ) # GW + with pytest.raises(ValueError): + solx_fgw = ot.solve_gromov( + Cax, + Cbx, + Mx, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + alpha=alpha, + loss=loss, + ) # FGW - solx = ot.solve_gromov( - Cax, - Cbx, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - loss=loss, - ) # GW - solx_fgw = ot.solve_gromov( - Cax, - Cbx, - Mx, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - alpha=alpha, - loss=loss, - ) # FGW - - solx.value_quad - - assert_allclose_sol(sol0, solx) - assert_allclose_sol(sol0_fgw, solx_fgw) + else: + sol0 = ot.solve_gromov( + Ca, + Cb, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + loss=loss, + ) # GW + sol0_fgw = ot.solve_gromov( + Ca, + Cb, + M, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + alpha=alpha, + loss=loss, + ) # FGW + + # solve in backend + ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb) + + solx = ot.solve_gromov( + Cax, + Cbx, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + loss=loss, + ) # GW + solx_fgw = ot.solve_gromov( + Cax, + Cbx, + Mx, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + alpha=alpha, + 
loss=loss, + ) # FGW + + solx.value_quad + + assert_allclose_sol(sol0, solx) + assert_allclose_sol(sol0_fgw, solx_fgw) except NotImplementedError: pytest.skip("Not implemented") From 4342fb006491a9845f4c20c9174be8b50501da0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Mon, 6 Jan 2025 02:14:40 +0100 Subject: [PATCH 13/16] fix solvers --- ot/solvers.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index decf6177e..5f8f65870 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -1040,9 +1040,17 @@ def solve_gromov( # potentials = (log['u'], log['v']) TODO else: # partial FGW - if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): - raise (ValueError("Partial FGW mass given in reg is too large")) + if unbalanced is None: + raise ( + ValueError( + "Partial GW mass given in `unbalanced` must be float and not None" + ) + ) + elif unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + raise ( + ValueError("Partial GW mass given in `unbalanced` is too large") + ) # default values for solver if max_iter is None: max_iter = 1000 From a5c011f468645cd59149a96d9ceb59e9d6b8b9c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Mon, 6 Jan 2025 13:47:57 +0100 Subject: [PATCH 14/16] improve example and doc --- examples/gromov/plot_partial_fgw.py | 51 +++++++++++++------ ot/gromov/_partial.py | 76 ++++++++++++++--------------- 2 files changed, 74 insertions(+), 53 deletions(-) diff --git a/examples/gromov/plot_partial_fgw.py b/examples/gromov/plot_partial_fgw.py index 87489ee46..cd6976074 100644 --- a/examples/gromov/plot_partial_fgw.py +++ b/examples/gromov/plot_partial_fgw.py @@ -1,11 +1,19 @@ # -*- coding: utf-8 -*- -""" +r""" ================================= Plot partial FGW for subgraph matching ================================= This example illustrates the computation of partial (Fused) Gromov-Wasserstein -divergences for subgraph matching tasks [18, 29]. 
+divergences for subgraph matching tasks, using the exact formulation $p(F)GW$ and +the entropically regularized one $p(F)GW_e$ [18, 29]. + +We first create a clean circular graph of 15 nodes with node features correlated with +node positions on the unit circle, and a noisy version where 5 nodes out of the +circle are added. Then knowing the proportion of clean samples in the target graph +$m=3/4$, we show how to identify them using : + - The partial GW matching and its entropic counterpart, omitting node features. + - The partial Fused GW matching and its entropic counterpart. [18] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas @@ -20,6 +28,8 @@ # # License: MIT License +# sphinx_gallery_thumbnail_number = 3 + # %% load libraries import numpy as np import pylab as pl @@ -35,7 +45,10 @@ entropic_partial_fused_gromov_wasserstein, ) from ot import unif, dist -# %% Graph generation and visualization functions + +############################################################################## +# Utils for generation and visualization +# ------------- def build_noisy_circular_graph(n_clean=15, n_noise=5, random_seed=0): @@ -56,6 +69,7 @@ def build_noisy_circular_graph(n_clean=15, n_noise=5, random_seed=0): for i, j in enumerate(noisy_nodes): g.add_node(i + n_clean, weight=math.sin(2 * j * math.pi / n_clean)) g.add_edge(i + n_clean, j) + g.add_edge(i + n_clean, (j + 1) % n_clean) return g @@ -138,11 +152,15 @@ def draw_graph( scaled_Gweights = Gweights / (0.5 * Gweights.max()) nodes_size = node_size * scaled_Gweights for node, node_color in enumerate(nodes_color_part): + if nodes_size[node] == 0: + local_node_size = 0 + else: + local_node_size = max(0.1 * node_size, nodes_size[node]) nx.draw_networkx_nodes( G, pos, nodelist=[node], - node_size=nodes_size[node], + node_size=local_node_size, alpha=1, node_color=node_color, ) @@ -198,17 +216,15 @@ def draw_transp_colored( [pos1[k1][1], pos2[k2][1]], "-", lw=0.8, - alpha=0.5 * T[k1, k2] / 
T_max, + alpha=max(0.05, 0.8 * T[k1, k2] / T_max), color=nodes_color_part1[k1], ) return pos1, pos2 -# %% ############################################################################## # Generate and visualize data # ------------- - # We build a clean circular graph that will be matched to a noisy circular graph. clean_graph = build_noisy_circular_graph(n_clean=15, n_noise=0) @@ -239,13 +255,16 @@ def draw_transp_colored( ############################################################################## # Partial (Entropic) Gromov-Wasserstein computation and visualization # ---------------------- - # Adjacency matrices are compared using both exact and entropic partial GW -# discarding for now node features +# discarding for now node features. +# Then for illustration, the node sizes are proportional to their optimized masses +# and the intensity of the link between two nodes across graphs is set proportionally +# to the corresponding transported mass. + Cs = [nx.adjacency_matrix(G).toarray().astype(np.float64) for G in graphs] ps = [unif(C.shape[0]) for C in Cs] -# provide an informative initialization for visualization +# provide an informative initialization for better visualization m = 3.0 / 4.0 partial_id = np.zeros((15, 20)) partial_id[:15, :15] = np.eye(15) / 15.0 @@ -268,18 +287,20 @@ def draw_transp_colored( np.round(logent["partial_gw_dist"], 3), ] list_dist_str = ["pGW", "pGW_e"] + pl.figure(2, figsize=(10, 3)) pl.clf() for i in range(2): pl.subplot(1, 2, i + 1) pl.axis("off") pl.title( - r"$%s(\mathbf{C_1},\mathbf{p_1},\mathbf{C_2}) =%s$" + r"$%s(\mathbf{C_1},\mathbf{p_1}^\star,\mathbf{C_2},\mathbf{p_2}^\star) =%s$" % (list_dist_str[i], list_dist[i]), fontsize=14, ) p2 = list_T[i].sum(0) + pos1, pos2 = draw_transp_colored( clean_graph, Cs[0], @@ -298,9 +319,9 @@ def draw_transp_colored( ############################################################################## # Partial (Entropic) Fused Gromov-Wasserstein computation and visualization # 
---------------------- - -# Add now node features compared using pairwise euclidean distance +# We add now node features compared using pairwise euclidean distance # to illustrate partial FGW computation with trade-off parameter alpha=0.5 + Ys = [ np.array([v for (k, v) in nx.get_node_attributes(G, "weight").items()]).reshape( -1, 1 @@ -308,7 +329,7 @@ def draw_transp_colored( for G in graphs ] M = dist(Ys[0], Ys[1]) -# provide an informative initialization for visualization +# provide an informative initialization for better visualization m = 3.0 / 4.0 partial_id = np.zeros((15, 20)) partial_id[:15, :15] = np.eye(15) / 15.0 @@ -357,7 +378,7 @@ def draw_transp_colored( pl.subplot(1, 2, i + 1) pl.axis("off") pl.title( - r"$%s(\mathbf{C_1},\mathbf{p_1},\mathbf{C_2}) =%s$" + r"$%s(\mathbf{C_1},\mathbf{p_1}^\star,\mathbf{C_2}, \mathbf{p_2}^\star) =%s$" % (list_dist_str[i], list_dist[i]), fontsize=14, ) diff --git a/ot/gromov/_partial.py b/ot/gromov/_partial.py index 5a069fdaf..994241e93 100644 --- a/ot/gromov/_partial.py +++ b/ot/gromov/_partial.py @@ -45,7 +45,7 @@ def partial_gromov_wasserstein( .. math:: \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} - L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) T_{i,j} T_{k,l} s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} @@ -332,7 +332,7 @@ def partial_gromov_wasserstein2( .. math:: \mathbf{PGW} = \mathop{\min}_\mathbf{T} \quad \sum_{i,j,k,l} - L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) T_{i,j} T_{k,l} s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} @@ -524,7 +524,7 @@ def partial_fused_gromov_wasserstein( .. 
math:: \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + - \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) T_{i,j} T_{k,l} s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} @@ -812,7 +812,7 @@ def partial_fused_gromov_wasserstein2( .. math:: \mathbf{PFGW}_{\alpha} = \mathop{\min}_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + - \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) T_{i,j} T_{k,l} s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} @@ -1088,18 +1088,18 @@ def entropic_partial_gromov_wasserstein( The function solves the following optimization problem: .. math:: - \gamma = \mathop{\arg \min}_{\gamma} \quad \sum_{i,j,k,l} - L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot - \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) + \mathbf{T} = \mathop{\arg \min}_{\mathbf{T}} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) + T_{i,j} T_{k,l} + \mathrm{reg} \Omega(\mathbf{T}) .. math:: - s.t. \ \gamma &\geq 0 + s.t. 
\ \mathbf{T} &\geq 0 - \gamma \mathbf{1} &\leq \mathbf{a} + \mathbf{T} \mathbf{1} &\leq \mathbf{a} - \gamma^T \mathbf{1} &\leq \mathbf{b} + \mathbf{T}^T \mathbf{1} &\leq \mathbf{b} - \mathbf{1}^T \gamma^T \mathbf{1} = m + \mathbf{1}^T \mathbf{T}^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : @@ -1109,7 +1109,7 @@ def entropic_partial_gromov_wasserstein( - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights - `L`: quadratic loss function - :math:`\Omega` is the entropic regularization term, - :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + :math:`\Omega(\mathbf{T})=\sum_{i,j} T_{i,j}\log(T_{i,j})` - `m` is the amount of mass to be transported The formulation of the GW problem has been proposed in @@ -1173,7 +1173,7 @@ def entropic_partial_gromov_wasserstein( Returns ------- - :math: `gamma` : ndarray, shape (dim_a, dim_b) + T : ndarray, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary returned only if `log` is `True` @@ -1327,18 +1327,18 @@ def entropic_partial_gromov_wasserstein2( The function solves the following optimization problem: .. math:: - PGW = \min_{\gamma} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, - \mathbf{C_2}_{j,l})\cdot - \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) + PGW = \min_{\mathbf{T}} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, + \mathbf{C_2}_{j,l}) + T_{i,j}T_{k,l} + \mathrm{reg} \Omega(\mathbf{T}) .. math:: - s.t. \ \gamma &\geq 0 + s.t. 
\ \mathbf{T} &\geq 0 - \gamma \mathbf{1} &\leq \mathbf{a} + \mathbf{T} \mathbf{1} &\leq \mathbf{a} - \gamma^T \mathbf{1} &\leq \mathbf{b} + \mathbf{T}^T \mathbf{1} &\leq \mathbf{b} - \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} + \mathbf{1}^T \mathbf{T}^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : @@ -1347,7 +1347,7 @@ def entropic_partial_gromov_wasserstein2( - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights - `L`: Loss function to account for the misfit between the similarity matrices. - :math:`\Omega` is the entropic regularization term, - :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + :math:`\Omega(\mathbf{T})=\sum_{i,j} T_{i,j}\log(T_{i,j})` - `m` is the amount of mass to be transported The formulation of the GW problem has been proposed in @@ -1461,18 +1461,18 @@ def entropic_partial_fused_gromov_wasserstein( The function solves the following optimization problem: .. math:: - \gamma = \mathop{\arg \min}_{\gamma} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F - + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot - \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) + \mathbf{T} = \mathop{\arg \min}_{\mathbf{T}} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) + T_{i,j} T_{k,l} + \mathrm{reg} \Omega(\mathbf{T}) .. math:: - s.t. \ \gamma &\geq 0 + s.t. 
\ \mathbf{T} &\geq 0 - \gamma \mathbf{1} &\leq \mathbf{a} + \mathbf{T} \mathbf{1} &\leq \mathbf{a} - \gamma^T \mathbf{1} &\leq \mathbf{b} + \mathbf{T}^T \mathbf{1} &\leq \mathbf{b} - \mathbf{1}^T \gamma^T \mathbf{1} = m + \mathbf{1}^T \mathbf{T}^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : @@ -1483,7 +1483,7 @@ def entropic_partial_fused_gromov_wasserstein( - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights - `L`: quadratic loss function - :math:`\Omega` is the entropic regularization term, - :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + :math:`\Omega(\mathbf{T})=\sum_{i,j} T_{i,j}\log(T_{i,j})` - `m` is the amount of mass to be transported The formulation of the FGW problem has been proposed in @@ -1530,7 +1530,7 @@ def entropic_partial_fused_gromov_wasserstein( Returns ------- - :math: `gamma` : ndarray, shape (dim_a, dim_b) + T : ndarray, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary returned only if `log` is `True` @@ -1693,18 +1693,18 @@ def entropic_partial_fused_gromov_wasserstein2( The function solves the following optimization problem: .. math:: - PGW = \min_{\gamma} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F - + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot - \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) + PFGW = \min_{\mathbf{T}} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) T_{i,j} T_{k,l} + + \mathrm{reg} \cdot\Omega(\mathbf{T}) .. math:: - s.t. + s.t.
\ \mathbf{T} &\geq 0 - \gamma \mathbf{1} &\leq \mathbf{a} + \mathbf{T} \mathbf{1} &\leq \mathbf{a} - \gamma^T \mathbf{1} &\leq \mathbf{b} + \mathbf{T}^T \mathbf{1} &\leq \mathbf{b} - \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} + \mathbf{1}^T \mathbf{T}^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : @@ -1714,7 +1714,7 @@ def entropic_partial_fused_gromov_wasserstein2( - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights - `L`: Loss function to account for the misfit between the similarity matrices. - :math:`\Omega` is the entropic regularization term, - :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + :math:`\Omega(\mathbf{T})=\sum_{i,j} T_{i,j}\log(T_{i,j})` - `m` is the amount of mass to be transported The formulation of the FGW problem has been proposed in From 95cd2e305bdc55d16212455f4e46821672adb741 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Mon, 6 Jan 2025 13:58:16 +0100 Subject: [PATCH 15/16] update readme --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index 3644605ce..45fddba1e 100644 --- a/README.md +++ b/README.md @@ -40,8 +40,7 @@ POT provides the following generic OT solvers (links to examples): * [Sampled solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33] * Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20]. * [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. 
Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41] -* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) and Partial Fused Gromov-Wasserstein (exact [29] and entropic [3] - formulations). +* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) and [Partial Fused Gromov-Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_partial_fgw.html) (exact [29] and entropic [3] formulations). * [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36]. * [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45] * [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46] From 5ab38c59cd2cd9952bb419b9f79d85a0aa8e1694 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Mon, 6 Jan 2025 17:35:47 +0100 Subject: [PATCH 16/16] reset solve_gromov behavior for unbalanced=none --- ot/solvers.py | 58 ++++++++------------ test/test_solvers.py | 123 +++++++++++++++++-------------------------- 2 files changed, 71 insertions(+), 110 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index 5f8f65870..4b30eced7 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -32,6 +32,9 @@ from .lowrank import lowrank_sinkhorn from .optim import cg +import warnings + + lst_method_lazy = [ "1d", "gaussian", @@ -658,7 +661,8 @@ def solve_gromov( ``alpha=0.5`` for Fused Gromov-Wasserstein problem (``M!=None``) unbalanced : float, 
optional Unbalanced penalization weight :math:`\lambda_u`, by default None - (balanced OT), Not implemented yet + (balanced OT). Not implemented yet for "KL" unbalanced penalization + function :math:`U`. Corresponds to the total transport mass for partial OT. unbalanced_type : str, optional Type of unbalanced penalization function :math:`U` either "KL", "semirelaxed", "partial", by default "KL" but note that it is not implemented yet. @@ -864,8 +868,14 @@ def solve_gromov( if reg is None or reg == 0: # exact OT if unbalanced is None and unbalanced_type.lower() not in [ "semirelaxed", - "partial", ]: # Exact balanced OT + if unbalanced_type.lower() in ["partial"]: + warnings.warn( + "Exact balanced OT is computed as `unbalanced=None` even though " + f"unbalanced_type = {unbalanced_type}.", + stacklevel=2, + ) + if M is None or alpha == 1: # Gromov-Wasserstein problem # default values for solver if max_iter is None: @@ -1002,14 +1012,7 @@ def solve_gromov( elif unbalanced_type.lower() in ["partial"]: # Partial OT if M is None or alpha == 1.0: # Partial Gromov-Wasserstein problem - if unbalanced is None: - raise ( - ValueError( - "Partial GW mass given in `unbalanced` must be float and not None" - ) - ) - - elif unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): raise ( ValueError("Partial GW mass given in `unbalanced` is too large") ) @@ -1040,14 +1043,7 @@ def solve_gromov( # potentials = (log['u'], log['v']) TODO else: # partial FGW - if unbalanced is None: - raise ( - ValueError( - "Partial GW mass given in `unbalanced` must be float and not None" - ) - ) - - elif unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): raise ( ValueError("Partial GW mass given in `unbalanced` is too large") ) @@ -1092,8 +1088,14 @@ def solve_gromov( else: # regularized OT if unbalanced is None and unbalanced_type.lower() not in [ "semirelaxed", - "partial", ]: # Balanced 
regularized OT + if unbalanced_type.lower() in ["partial"]: + warnings.warn( + "Exact balanced OT is computed as `unbalanced=None` even though " + f"unbalanced_type = {unbalanced_type}.", + stacklevel=2, + ) + if reg_type.lower() in ["entropy"] and ( M is None or alpha == 1 ): # Entropic Gromov-Wasserstein problem @@ -1250,14 +1252,7 @@ def solve_gromov( elif unbalanced_type.lower() in ["partial"]: # Partial OT if M is None or alpha == 1.0: # Partial Gromov-Wasserstein problem - if unbalanced is None: - raise ( - ValueError( - "Partial GW mass given in `unbalanced` must be float and not None" - ) - ) - - elif unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): raise ( ValueError("Partial GW mass given in `unbalanced` is too large") ) @@ -1289,14 +1284,7 @@ def solve_gromov( # potentials = (log['u'], log['v']) TODO value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) else: # partial FGW - if unbalanced is None: - raise ( - ValueError( - "Partial GW mass given in `unbalanced` must be float and not None" - ) - ) - - elif unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): raise ( ValueError("Partial GW mass given in `unbalanced` is too large") ) diff --git a/test/test_solvers.py b/test/test_solvers.py index b1bd097a3..a0c1d7c43 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -432,82 +432,55 @@ def test_solve_gromov_grid(nx, reg, reg_type, unbalanced, unbalanced_type, alpha M = np.random.rand(n_samples_s, n_samples_t) try: - if unbalanced_type == "partial" and unbalanced is None: - ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb) - - with pytest.raises(ValueError): - solx = ot.solve_gromov( - Cax, - Cbx, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - loss=loss, - ) # GW - with pytest.raises(ValueError): - solx_fgw = ot.solve_gromov( - Cax, - Cbx, - Mx, - reg=reg, - reg_type=reg_type, - 
unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - alpha=alpha, - loss=loss, - ) # FGW + sol0 = ot.solve_gromov( + Ca, + Cb, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + loss=loss, + ) # GW + sol0_fgw = ot.solve_gromov( + Ca, + Cb, + M, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + alpha=alpha, + loss=loss, + ) # FGW - else: - sol0 = ot.solve_gromov( - Ca, - Cb, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - loss=loss, - ) # GW - sol0_fgw = ot.solve_gromov( - Ca, - Cb, - M, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - alpha=alpha, - loss=loss, - ) # FGW - - # solve in backend - ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb) - - solx = ot.solve_gromov( - Cax, - Cbx, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - loss=loss, - ) # GW - solx_fgw = ot.solve_gromov( - Cax, - Cbx, - Mx, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - alpha=alpha, - loss=loss, - ) # FGW - - solx.value_quad - - assert_allclose_sol(sol0, solx) - assert_allclose_sol(sol0_fgw, solx_fgw) + # solve in backend + ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb) + + solx = ot.solve_gromov( + Cax, + Cbx, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + loss=loss, + ) # GW + solx_fgw = ot.solve_gromov( + Cax, + Cbx, + Mx, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + alpha=alpha, + loss=loss, + ) # FGW + + solx.value_quad + + assert_allclose_sol(sol0, solx) + assert_allclose_sol(sol0_fgw, solx_fgw) except NotImplementedError: pytest.skip("Not implemented")