Skip to content

Commit

Permalink
add documentation and warnings in (f)gw cg solvers when integers are …
Browse files Browse the repository at this point in the history
…provided (#560)
  • Loading branch information
cedricvincentcuaz authored Nov 7, 2023
1 parent 1ece2d8 commit 1682b60
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 13 deletions.
2 changes: 1 addition & 1 deletion RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
- Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (Issue #516, PR #520)

- Handle documentation and warnings when integers are provided to (f)gw solvers based on cg (Issue #530, PR #559)

## 0.9.1
*August 2023*
Expand Down
57 changes: 46 additions & 11 deletions ot/gromov/_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# License: MIT License

import numpy as np
import warnings


from ..utils import dist, UndefinedParameter, list_to_array
Expand Down Expand Up @@ -53,6 +54,10 @@ def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric
which can lead to copy overhead on GPU arrays.
.. note:: All computations in the conjugate gradient solver are done with
numpy to limit memory overhead.
.. note:: This function will cast the computed transport plan to the data
type of the provided input :math:`\mathbf{C}_1`. Casting to an integer
tensor might result in a loss of precision. If this behaviour is
unwanted, please make sure to provide a floating point input.
Parameters
----------
Expand Down Expand Up @@ -122,7 +127,7 @@ def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric
if q is not None:
arr.append(list_to_array(q))
else:
q = unif(C2.shape[0], type_as=C2)
q = unif(C2.shape[0], type_as=C1)
if G0 is not None:
G0_ = G0
arr.append(G0)
Expand Down Expand Up @@ -171,6 +176,16 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
else:
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, **kwargs)

if not nx.is_floating_point(C10):
warnings.warn(
"Input structure matrix consists of integer. The transport plan will be "
"casted accordingly, possibly resulting in a loss of precision. "
"If this behaviour is unwanted, please make sure your input "
"structure matrix consists of floating point elements.",
stacklevel=2
)

if log:
res, log = cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
log['gw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10)
Expand Down Expand Up @@ -216,6 +231,10 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri
which can lead to copy overhead on GPU arrays.
.. note:: All computations in the conjugate gradient solver are done with
numpy to limit memory overhead.
.. note:: This function will cast the computed transport plan to the data
type of the provided input :math:`\mathbf{C}_1`. Casting to an integer
tensor might result in a loss of precision. If this behaviour is
unwanted, please make sure to provide a floating point input.
Parameters
----------
Expand Down Expand Up @@ -286,7 +305,7 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri
if p is None:
p = unif(C1.shape[0], type_as=C1)
if q is None:
q = unif(C2.shape[0], type_as=C2)
q = unif(C2.shape[0], type_as=C1)

T, log_gw = gromov_wasserstein(
C1, C2, p, q, loss_fun, symmetric, log=True, armijo=armijo, G0=G0,
Expand Down Expand Up @@ -344,6 +363,10 @@ def fused_gromov_wasserstein(M, C1, C2, p=None, q=None, loss_fun='square_loss',
which can lead to copy overhead on GPU arrays.
.. note:: All computations in the conjugate gradient solver are done with
numpy to limit memory overhead.
.. note:: This function will cast the computed transport plan to the data
type of the provided input :math:`\mathbf{M}`. Casting to an integer
tensor might result in a loss of precision. If this behaviour is
unwanted, please make sure to provide a floating point input.
Parameters
Expand Down Expand Up @@ -409,11 +432,11 @@ def fused_gromov_wasserstein(M, C1, C2, p=None, q=None, loss_fun='square_loss',
if p is not None:
arr.append(list_to_array(p))
else:
p = unif(C1.shape[0], type_as=C1)
p = unif(C1.shape[0], type_as=M)
if q is not None:
arr.append(list_to_array(q))
else:
q = unif(C2.shape[0], type_as=C2)
q = unif(C2.shape[0], type_as=M)
if G0 is not None:
G0_ = G0
arr.append(G0)
Expand Down Expand Up @@ -465,14 +488,22 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
else:
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs)
if not nx.is_floating_point(M0):
warnings.warn(
"Input feature matrix consists of integer. The transport plan will be "
"casted accordingly, possibly resulting in a loss of precision. "
"If this behaviour is unwanted, please make sure your input "
"feature matrix consists of floating point elements.",
stacklevel=2
)
if log:
res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10)
log['u'] = nx.from_numpy(log['u'], type_as=C10)
log['v'] = nx.from_numpy(log['v'], type_as=C10)
return nx.from_numpy(res, type_as=C10), log
log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=M0)
log['u'] = nx.from_numpy(log['u'], type_as=M0)
log['v'] = nx.from_numpy(log['v'], type_as=M0)
return nx.from_numpy(res, type_as=M0), log
else:
return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10)
return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=M0)


def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, alpha=0.5,
Expand Down Expand Up @@ -510,6 +541,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
which can lead to copy overhead on GPU arrays.
.. note:: All computations in the conjugate gradient solver are done with
numpy to limit memory overhead.
.. note:: This function will cast the computed transport plan to the data
type of the provided input :math:`\mathbf{M}`. Casting to an integer
tensor might result in a loss of precision. If this behaviour is
unwanted, please make sure to provide a floating point input.
Parameters
----------
Expand Down Expand Up @@ -578,9 +613,9 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',

# init marginals if set as None
if p is None:
p = unif(C1.shape[0], type_as=C1)
p = unif(C1.shape[0], type_as=M)
if q is None:
q = unif(C2.shape[0], type_as=C2)
q = unif(C2.shape[0], type_as=M)

T, log_fgw = fused_gromov_wasserstein(
M, C1, C2, p, q, loss_fun, symmetric, alpha, armijo, G0, log=True,
Expand Down
59 changes: 58 additions & 1 deletion test/test_gromov.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,37 @@ def test_asymmetric_gromov(nx):
np.testing.assert_allclose(logb['gw_dist'], 0., atol=1e-04)


def test_gromov_integer_warnings(nx):
n_samples = 10 # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])

xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1)
xt = xs[::-1].copy()

p = ot.unif(n_samples)
q = ot.unif(n_samples)
G0 = p[:, None] * q[None, :]

C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)

C1 /= C1.max()
C2 /= C2.max()
C1 = C1.astype(np.int32)
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)

G = ot.gromov.gromov_wasserstein(
C1, C2, None, q, 'square_loss', G0=G0, verbose=True,
alpha_min=0., alpha_max=1.)
Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(
C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=G0b, verbose=True))

# check constraints
np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(G, 0., atol=1e-09)


def test_gromov_dtype_device(nx):
# setup
n_samples = 20 # nb samples
Expand Down Expand Up @@ -1145,7 +1176,7 @@ def test_fgw(nx):


def test_asymmetric_fgw(nx):
n_samples = 50 # nb samples
n_samples = 20 # nb samples
rng = np.random.RandomState(0)
C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples))
idx = np.arange(n_samples)
Expand Down Expand Up @@ -1221,6 +1252,32 @@ def test_asymmetric_fgw(nx):
np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)


def test_fgw_integer_warnings(nx):
n_samples = 20 # nb samples
rng = np.random.RandomState(0)
C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples))
idx = np.arange(n_samples)
rng.shuffle(idx)
C2 = C1[idx, :][:, idx]

# add features
F1 = rng.uniform(low=0., high=10, size=(n_samples, 1))
F2 = F1[idx, :]
p = ot.unif(n_samples)
q = ot.unif(n_samples)
G0 = p[:, None] * q[None, :]

M = ot.dist(F1, F2).astype(np.int32)
Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)

G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, symmetric=False, verbose=True)
Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, symmetric=None, G0=G0b, verbose=True)
Gb = nx.to_numpy(Gb)
# check constraints
np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(G, 0., atol=1e-06)


def test_fgw2_gradients():
n_samples = 20 # nb samples

Expand Down

0 comments on commit 1682b60

Please sign in to comment.