Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] New API for gromov solvers #536

Merged
merged 16 commits into from
Oct 24, 2023
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
+ Callbacks for generalized conditional gradient in `ot.da.sinkhorn_l1l2_gl` are now vectorized to improve performance (PR #507)
+ The `linspace` method of the backends now has the `type_as` argument to convert to the same dtype and device. (PR #533)
+ The `convolutional_barycenter2d` and `convolutional_barycenter2d_debiased` functions now work with different devices.. (PR #533)
+ New API for Gromov-Wasserstein solvers with `ot.solve_gromov` function (PR #536)

#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
Expand Down
4 changes: 2 additions & 2 deletions ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
from .weak import weak_optimal_transport
from .factored import factored_optimal_transport
from .solvers import solve
from .solvers import solve, solve_gromov

# utils functions
from .utils import dist, unif, tic, toc, toq
Expand All @@ -65,7 +65,7 @@
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere',
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein',
'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport',
'factored_optimal_transport', 'solve',
'factored_optimal_transport', 'solve', 'solve_gromov',
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers',
'binary_search_circle', 'wasserstein_circle',
'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif']
1 change: 1 addition & 0 deletions ot/da.py
Original file line number Diff line number Diff line change
Expand Up @@ -2274,6 +2274,7 @@ class NearestBrenierPotential(BaseTransport):
ot.mapping.nearest_brenier_potential_fit : Fitting the SSNB on source and target data
ot.mapping.nearest_brenier_potential_predict_bounds : Predicting SSNB images on new source data
"""

def __init__(self, strongly_convex_constant=0.6, gradient_lipschitz_constant=1.4, log=False, its=100, seed=None):
self.strongly_convex_constant = strongly_convex_constant
self.gradient_lipschitz_constant = gradient_lipschitz_constant
Expand Down
10 changes: 10 additions & 0 deletions ot/gromov/_bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ def entropic_gromov_wasserstein2(
learning for graph matching and node embedding. In International
Conference on Machine Learning (ICML), 2019.
"""

T, logv = entropic_gromov_wasserstein(
C1, C2, p, q, loss_fun, epsilon, symmetric, G0, max_iter,
tol, solver, warmstart, verbose, log=True, **kwargs)
Expand Down Expand Up @@ -815,12 +816,21 @@ def entropic_fused_gromov_wasserstein2(
(ICML). 2019.

"""

nx = get_backend(M, C1, C2)

T, logv = entropic_fused_gromov_wasserstein(
M, C1, C2, p, q, loss_fun, epsilon, symmetric, alpha, G0, max_iter,
tol, solver, warmstart, verbose, log=True, **kwargs)

logv['T'] = T

lin_term = nx.sum(T * M)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could avoid this overhead by defining both logv['quad_loss'] and logv['lin_loss'] directly in entropic_fused_gromov_wasserstein if log=True

gw_term = (logv['fgw_dist'] - (1 - alpha) * lin_term) / alpha

logv['quad_loss'] = gw_term * alpha
logv['lin_loss'] = lin_term * (1 - alpha)

if log:
return logv['fgw_dist'], logv
else:
Expand Down
10 changes: 8 additions & 2 deletions ot/gromov/_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,13 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
fgw_dist = log_fgw['fgw_dist']
log_fgw['T'] = T

# compute separate terms for gradients and log
lin_term = nx.sum(T * M)
gw_term = (fgw_dist - (1 - alpha) * lin_term) / alpha
rflamary marked this conversation as resolved.
Show resolved Hide resolved

log_fgw['quad_loss'] = gw_term * alpha
log_fgw['lin_loss'] = lin_term * (1 - alpha)

if loss_fun == 'square_loss':
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
Expand All @@ -591,8 +598,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
log_fgw['v'] - nx.mean(log_fgw['v']),
alpha * gC1, alpha * gC2, (1 - alpha) * T))
else:
lin_term = nx.sum(T * M)
gw_term = (fgw_dist - (1 - alpha) * lin_term) / alpha

fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha),
(log_fgw['u'] - nx.mean(log_fgw['u']),
log_fgw['v'] - nx.mean(log_fgw['v']),
Expand Down
6 changes: 5 additions & 1 deletion ot/gromov/_semirelaxed.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,8 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo
q = nx.sum(T, 0)
srfgw_dist = log_fgw['srfgw_dist']
log_fgw['T'] = T
log_fgw['lin_loss'] = nx.sum(M * T) * (1 - alpha)
log_fgw['quad_loss'] = srfgw_dist - log_fgw['lin_loss']

if loss_fun == 'square_loss':
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
Expand Down Expand Up @@ -979,7 +981,9 @@ def df(G):
if log:
qG = nx.sum(G, 0)
marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t))
log['srfgw_dist'] = alpha * gwloss(constC + marginal_product, hC1, hC2, G, nx) + (1 - alpha) * nx.sum(M * G)
log['lin_loss'] = nx.sum(M * G) * (1 - alpha)
log['quad_loss'] = alpha * gwloss(constC + marginal_product, hC1, hC2, G, nx)
log['srfgw_dist'] = log['lin_loss'] + log['quad_loss']
return G, log
else:
return G
Expand Down
Loading
Loading