Skip to content

Commit

Permalink
reset solve_gromov behavior for unbalanced=none
Browse files Browse the repository at this point in the history
  • Loading branch information
cedricvincentcuaz committed Jan 6, 2025
1 parent 95cd2e3 commit 5ab38c5
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 110 deletions.
58 changes: 23 additions & 35 deletions ot/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
from .lowrank import lowrank_sinkhorn
from .optim import cg

import warnings


lst_method_lazy = [
"1d",
"gaussian",
Expand Down Expand Up @@ -658,7 +661,8 @@ def solve_gromov(
``alpha=0.5`` for Fused Gromov-Wasserstein problem (``M!=None``)
unbalanced : float, optional
Unbalanced penalization weight :math:`\lambda_u`, by default None
(balanced OT), Not implemented yet
(balanced OT). Not implemented yet for "KL" unbalanced penalization
function :math:`U`. Corresponds to the total transport mass for partial OT.
unbalanced_type : str, optional
Type of unbalanced penalization function :math:`U` either "KL", "semirelaxed",
"partial", by default "KL" but note that it is not implemented yet.
Expand Down Expand Up @@ -864,8 +868,14 @@ def solve_gromov(
if reg is None or reg == 0: # exact OT
if unbalanced is None and unbalanced_type.lower() not in [
"semirelaxed",
"partial",
]: # Exact balanced OT
if unbalanced_type.lower() in ["partial"]:
warnings.warn(
"Exact balanced OT is computed as `unbalanced=None` even though "
f"unbalanced_type = {unbalanced_type}.",
stacklevel=2,
)

if M is None or alpha == 1: # Gromov-Wasserstein problem
# default values for solver
if max_iter is None:
Expand Down Expand Up @@ -1002,14 +1012,7 @@ def solve_gromov(

elif unbalanced_type.lower() in ["partial"]: # Partial OT
if M is None or alpha == 1.0: # Partial Gromov-Wasserstein problem
if unbalanced is None:
raise (
ValueError(
"Partial GW mass given in `unbalanced` must be float and not None"
)
)

elif unbalanced > nx.sum(a) or unbalanced > nx.sum(b):
if unbalanced > nx.sum(a) or unbalanced > nx.sum(b):
raise (
ValueError("Partial GW mass given in `unbalanced` is too large")
)
Expand Down Expand Up @@ -1040,14 +1043,7 @@ def solve_gromov(
# potentials = (log['u'], log['v']) TODO

else: # partial FGW
if unbalanced is None:
raise (
ValueError(
"Partial GW mass given in `unbalanced` must be float and not None"
)
)

elif unbalanced > nx.sum(a) or unbalanced > nx.sum(b):
if unbalanced > nx.sum(a) or unbalanced > nx.sum(b):
raise (
ValueError("Partial GW mass given in `unbalanced` is too large")
)
Expand Down Expand Up @@ -1092,8 +1088,14 @@ def solve_gromov(
else: # regularized OT
if unbalanced is None and unbalanced_type.lower() not in [
"semirelaxed",
"partial",
]: # Balanced regularized OT
if unbalanced_type.lower() in ["partial"]:
warnings.warn(
"Exact balanced OT is computed as `unbalanced=None` even though "
f"unbalanced_type = {unbalanced_type}.",
stacklevel=2,
)

if reg_type.lower() in ["entropy"] and (
M is None or alpha == 1
): # Entropic Gromov-Wasserstein problem
Expand Down Expand Up @@ -1250,14 +1252,7 @@ def solve_gromov(

elif unbalanced_type.lower() in ["partial"]: # Partial OT
if M is None or alpha == 1.0: # Partial Gromov-Wasserstein problem
if unbalanced is None:
raise (
ValueError(
"Partial GW mass given in `unbalanced` must be float and not None"
)
)

elif unbalanced > nx.sum(a) or unbalanced > nx.sum(b):
if unbalanced > nx.sum(a) or unbalanced > nx.sum(b):
raise (
ValueError("Partial GW mass given in `unbalanced` is too large")
)
Expand Down Expand Up @@ -1289,14 +1284,7 @@ def solve_gromov(
# potentials = (log['u'], log['v']) TODO
value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16))
else: # partial FGW
if unbalanced is None:
raise (
ValueError(
"Partial GW mass given in `unbalanced` must be float and not None"
)
)

elif unbalanced > nx.sum(a) or unbalanced > nx.sum(b):
if unbalanced > nx.sum(a) or unbalanced > nx.sum(b):
raise (
ValueError("Partial GW mass given in `unbalanced` is too large")
)
Expand Down
123 changes: 48 additions & 75 deletions test/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,82 +432,55 @@ def test_solve_gromov_grid(nx, reg, reg_type, unbalanced, unbalanced_type, alpha
M = np.random.rand(n_samples_s, n_samples_t)

try:
if unbalanced_type == "partial" and unbalanced is None:
ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb)

with pytest.raises(ValueError):
solx = ot.solve_gromov(
Cax,
Cbx,
reg=reg,
reg_type=reg_type,
unbalanced=unbalanced,
unbalanced_type=unbalanced_type,
loss=loss,
) # GW
with pytest.raises(ValueError):
solx_fgw = ot.solve_gromov(
Cax,
Cbx,
Mx,
reg=reg,
reg_type=reg_type,
unbalanced=unbalanced,
unbalanced_type=unbalanced_type,
alpha=alpha,
loss=loss,
) # FGW
sol0 = ot.solve_gromov(
Ca,
Cb,
reg=reg,
reg_type=reg_type,
unbalanced=unbalanced,
unbalanced_type=unbalanced_type,
loss=loss,
) # GW
sol0_fgw = ot.solve_gromov(
Ca,
Cb,
M,
reg=reg,
reg_type=reg_type,
unbalanced=unbalanced,
unbalanced_type=unbalanced_type,
alpha=alpha,
loss=loss,
) # FGW

else:
sol0 = ot.solve_gromov(
Ca,
Cb,
reg=reg,
reg_type=reg_type,
unbalanced=unbalanced,
unbalanced_type=unbalanced_type,
loss=loss,
) # GW
sol0_fgw = ot.solve_gromov(
Ca,
Cb,
M,
reg=reg,
reg_type=reg_type,
unbalanced=unbalanced,
unbalanced_type=unbalanced_type,
alpha=alpha,
loss=loss,
) # FGW

# solve in backend
ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb)

solx = ot.solve_gromov(
Cax,
Cbx,
reg=reg,
reg_type=reg_type,
unbalanced=unbalanced,
unbalanced_type=unbalanced_type,
loss=loss,
) # GW
solx_fgw = ot.solve_gromov(
Cax,
Cbx,
Mx,
reg=reg,
reg_type=reg_type,
unbalanced=unbalanced,
unbalanced_type=unbalanced_type,
alpha=alpha,
loss=loss,
) # FGW

solx.value_quad

assert_allclose_sol(sol0, solx)
assert_allclose_sol(sol0_fgw, solx_fgw)
# solve in backend
ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb)

solx = ot.solve_gromov(
Cax,
Cbx,
reg=reg,
reg_type=reg_type,
unbalanced=unbalanced,
unbalanced_type=unbalanced_type,
loss=loss,
) # GW
solx_fgw = ot.solve_gromov(
Cax,
Cbx,
Mx,
reg=reg,
reg_type=reg_type,
unbalanced=unbalanced,
unbalanced_type=unbalanced_type,
alpha=alpha,
loss=loss,
) # FGW

solx.value_quad

assert_allclose_sol(sol0, solx)
assert_allclose_sol(sol0_fgw, solx_fgw)

except NotImplementedError:
pytest.skip("Not implemented")
Expand Down

0 comments on commit 5ab38c5

Please sign in to comment.