Skip to content

Commit

Permalink
[MRG] New API for gromov solvers (#536)
Browse files Browse the repository at this point in the history
* add new API for gromov

* small bug entropic fgw

* f* pep8

* add semirelaxed

* al is working and is tested

* documentation for solve_gromov

* update documentaton and add ârtial

* pep8

* pep8

* better tests + release file

* take comments into account$

* it should work now

* last commets cedric
  • Loading branch information
rflamary authored Oct 24, 2023
1 parent 57eda61 commit a9de7a0
Show file tree
Hide file tree
Showing 11 changed files with 679 additions and 22 deletions.
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
+ Callbacks for generalized conditional gradient in `ot.da.sinkhorn_l1l2_gl` are now vectorized to improve performance (PR #507)
+ The `linspace` method of the backends now has the `type_as` argument to convert to the same dtype and device. (PR #533)
+ The `convolutional_barycenter2d` and `convolutional_barycenter2d_debiased` functions now work with different devices.. (PR #533)
+ New API for Gromov-Wasserstein solvers with `ot.solve_gromov` function (PR #536)

#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
Expand Down
4 changes: 2 additions & 2 deletions ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
from .weak import weak_optimal_transport
from .factored import factored_optimal_transport
from .solvers import solve
from .solvers import solve, solve_gromov

# utils functions
from .utils import dist, unif, tic, toc, toq
Expand All @@ -65,7 +65,7 @@
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere',
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein',
'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport',
'factored_optimal_transport', 'solve',
'factored_optimal_transport', 'solve', 'solve_gromov',
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers',
'binary_search_circle', 'wasserstein_circle',
'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif']
1 change: 1 addition & 0 deletions ot/da.py
Original file line number Diff line number Diff line change
Expand Up @@ -2274,6 +2274,7 @@ class NearestBrenierPotential(BaseTransport):
ot.mapping.nearest_brenier_potential_fit : Fitting the SSNB on source and target data
ot.mapping.nearest_brenier_potential_predict_bounds : Predicting SSNB images on new source data
"""

def __init__(self, strongly_convex_constant=0.6, gradient_lipschitz_constant=1.4, log=False, its=100, seed=None):
self.strongly_convex_constant = strongly_convex_constant
self.gradient_lipschitz_constant = gradient_lipschitz_constant
Expand Down
8 changes: 8 additions & 0 deletions ot/gromov/_bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ def entropic_gromov_wasserstein2(
learning for graph matching and node embedding. In International
Conference on Machine Learning (ICML), 2019.
"""

T, logv = entropic_gromov_wasserstein(
C1, C2, p, q, loss_fun, epsilon, symmetric, G0, max_iter,
tol, solver, warmstart, verbose, log=True, **kwargs)
Expand Down Expand Up @@ -815,12 +816,19 @@ def entropic_fused_gromov_wasserstein2(
(ICML). 2019.
"""

nx = get_backend(M, C1, C2)

T, logv = entropic_fused_gromov_wasserstein(
M, C1, C2, p, q, loss_fun, epsilon, symmetric, alpha, G0, max_iter,
tol, solver, warmstart, verbose, log=True, **kwargs)

logv['T'] = T

lin_term = nx.sum(T * M)
logv['quad_loss'] = (logv['fgw_dist'] - (1 - alpha) * lin_term)
logv['lin_loss'] = lin_term * (1 - alpha)

if log:
return logv['fgw_dist'], logv
else:
Expand Down
9 changes: 7 additions & 2 deletions ot/gromov/_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,12 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
fgw_dist = log_fgw['fgw_dist']
log_fgw['T'] = T

# compute separate terms for gradients and log
lin_term = nx.sum(T * M)
log_fgw['quad_loss'] = (fgw_dist - (1 - alpha) * lin_term)
log_fgw['lin_loss'] = lin_term * (1 - alpha)
gw_term = log_fgw['quad_loss'] / alpha

if loss_fun == 'square_loss':
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
Expand All @@ -591,8 +597,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
log_fgw['v'] - nx.mean(log_fgw['v']),
alpha * gC1, alpha * gC2, (1 - alpha) * T))
else:
lin_term = nx.sum(T * M)
gw_term = (fgw_dist - (1 - alpha) * lin_term) / alpha

fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha),
(log_fgw['u'] - nx.mean(log_fgw['u']),
log_fgw['v'] - nx.mean(log_fgw['v']),
Expand Down
6 changes: 5 additions & 1 deletion ot/gromov/_semirelaxed.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,8 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo
q = nx.sum(T, 0)
srfgw_dist = log_fgw['srfgw_dist']
log_fgw['T'] = T
log_fgw['lin_loss'] = nx.sum(M * T) * (1 - alpha)
log_fgw['quad_loss'] = srfgw_dist - log_fgw['lin_loss']

if loss_fun == 'square_loss':
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
Expand Down Expand Up @@ -979,7 +981,9 @@ def df(G):
if log:
qG = nx.sum(G, 0)
marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t))
log['srfgw_dist'] = alpha * gwloss(constC + marginal_product, hC1, hC2, G, nx) + (1 - alpha) * nx.sum(M * G)
log['lin_loss'] = nx.sum(M * G) * (1 - alpha)
log['quad_loss'] = alpha * gwloss(constC + marginal_product, hC1, hC2, G, nx)
log['srfgw_dist'] = log['lin_loss'] + log['quad_loss']
return G, log
else:
return G
Expand Down
Loading

0 comments on commit a9de7a0

Please sign in to comment.