[WIP] documentation and warnings in (f)gw cg solvers for integer inputs #560

Merged · 3 commits · Nov 7, 2023
2 changes: 1 addition & 1 deletion RELEASES.md
@@ -20,7 +20,7 @@
#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
- Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (Issue #516, PR #520)

+- Handle documentation and warnings when integers are provided to (f)gw solvers based on cg (Issue #530, PR #559)

## 0.9.1
*August 2023*
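The behaviour this entry addresses is easy to reproduce. Below is a minimal sketch (not part of the PR; illustrative values, assuming a POT build that includes this patch) showing the new integer-input warning and the truncation it documents:

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
n = 6
C = rng.randint(0, 10, size=(n, n))
C = C + C.T  # symmetric structure matrix with an integer dtype
p, q = ot.unif(n), ot.unif(n)  # float64 uniform marginals

# The solver warns about the integer input: the plan is computed in
# float64 internally, then cast back to the dtype of C, so entries of
# order 1/n are truncated to 0.
G = ot.gromov.gromov_wasserstein(C, C, p, q)
print(G.dtype, G.sum())  # integer dtype; mass lost to truncation
```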
57 changes: 46 additions & 11 deletions ot/gromov/_gw.py
@@ -12,6 +12,7 @@
# License: MIT License

import numpy as np
+import warnings


from ..utils import dist, UndefinedParameter, list_to_array
@@ -53,6 +54,10 @@ def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric
which can lead to copy overhead on GPU arrays.
.. note:: All computations in the conjugate gradient solver are done with
numpy to limit memory overhead.
+    .. note:: This function will cast the computed transport plan to the data
+        type of the provided input :math:`\mathbf{C}_1`. Casting to an integer
+        tensor might result in a loss of precision. If this behaviour is
+        unwanted, please make sure to provide a floating point input.

Parameters
----------
@@ -122,7 +127,7 @@
if q is not None:
arr.append(list_to_array(q))
else:
-        q = unif(C2.shape[0], type_as=C2)
+        q = unif(C2.shape[0], type_as=C1)
if G0 is not None:
G0_ = G0
arr.append(G0)
@@ -171,6 +176,16 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
else:
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, **kwargs)

+    if not nx.is_floating_point(C10):
+        warnings.warn(
+            "Input structure matrix consists of integers. The transport plan will be "
+            "cast accordingly, possibly resulting in a loss of precision. "
+            "If this behaviour is unwanted, please make sure your input "
+            "structure matrix consists of floating point elements.",
+            stacklevel=2
+        )

if log:
res, log = cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
log['gw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10)
@@ -216,6 +231,10 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric
which can lead to copy overhead on GPU arrays.
.. note:: All computations in the conjugate gradient solver are done with
numpy to limit memory overhead.
+    .. note:: This function will cast the computed transport plan to the data
+        type of the provided input :math:`\mathbf{C}_1`. Casting to an integer
+        tensor might result in a loss of precision. If this behaviour is
+        unwanted, please make sure to provide a floating point input.

Parameters
----------
@@ -286,7 +305,7 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric
if p is None:
p = unif(C1.shape[0], type_as=C1)
if q is None:
-        q = unif(C2.shape[0], type_as=C2)
+        q = unif(C2.shape[0], type_as=C1)

T, log_gw = gromov_wasserstein(
C1, C2, p, q, loss_fun, symmetric, log=True, armijo=armijo, G0=G0,
@@ -344,6 +363,10 @@ def fused_gromov_wasserstein(M, C1, C2, p=None, q=None, loss_fun='square_loss',
which can lead to copy overhead on GPU arrays.
.. note:: All computations in the conjugate gradient solver are done with
numpy to limit memory overhead.
+    .. note:: This function will cast the computed transport plan to the data
+        type of the provided input :math:`\mathbf{M}`. Casting to an integer
+        tensor might result in a loss of precision. If this behaviour is
+        unwanted, please make sure to provide a floating point input.


Parameters
@@ -409,11 +432,11 @@ def fused_gromov_wasserstein(M, C1, C2, p=None, q=None, loss_fun='square_loss',
if p is not None:
arr.append(list_to_array(p))
else:
-        p = unif(C1.shape[0], type_as=C1)
+        p = unif(C1.shape[0], type_as=M)
if q is not None:
arr.append(list_to_array(q))
else:
-        q = unif(C2.shape[0], type_as=C2)
+        q = unif(C2.shape[0], type_as=M)
if G0 is not None:
G0_ = G0
arr.append(G0)
@@ -465,14 +488,22 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
else:
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs)
+    if not nx.is_floating_point(M0):
+        warnings.warn(
+            "Input feature matrix consists of integers. The transport plan will be "
+            "cast accordingly, possibly resulting in a loss of precision. "
+            "If this behaviour is unwanted, please make sure your input "
+            "feature matrix consists of floating point elements.",
+            stacklevel=2
+        )
if log:
res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
-        log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10)
-        log['u'] = nx.from_numpy(log['u'], type_as=C10)
-        log['v'] = nx.from_numpy(log['v'], type_as=C10)
-        return nx.from_numpy(res, type_as=C10), log
+        log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=M0)
+        log['u'] = nx.from_numpy(log['u'], type_as=M0)
+        log['v'] = nx.from_numpy(log['v'], type_as=M0)
+        return nx.from_numpy(res, type_as=M0), log
else:
-        return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10)
+        return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=M0)


def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, alpha=0.5,
@@ -510,6 +541,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',
which can lead to copy overhead on GPU arrays.
.. note:: All computations in the conjugate gradient solver are done with
numpy to limit memory overhead.
+    .. note:: This function will cast the computed transport plan to the data
+        type of the provided input :math:`\mathbf{M}`. Casting to an integer
+        tensor might result in a loss of precision. If this behaviour is
+        unwanted, please make sure to provide a floating point input.

Parameters
----------
@@ -578,9 +613,9 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss',

# init marginals if set as None
if p is None:
-        p = unif(C1.shape[0], type_as=C1)
+        p = unif(C1.shape[0], type_as=M)
if q is None:
-        q = unif(C2.shape[0], type_as=C2)
+        q = unif(C2.shape[0], type_as=M)

T, log_fgw = fused_gromov_wasserstein(
M, C1, C2, p, q, loss_fun, symmetric, alpha, armijo, G0, log=True,
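With the plan's dtype now following the input matrices, the caller-side remedy suggested by the new docstring notes is simply to cast integer inputs to a floating point dtype before calling the solver. A short sketch under the same assumptions as above:

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
n = 6
C_int = rng.randint(0, 10, size=(n, n))
C_int = C_int + C_int.T

# Casting up front silences the warning and keeps the returned plan
# (and log entries such as 'gw_dist') in float64.
C = C_int.astype(np.float64)
G, log = ot.gromov.gromov_wasserstein(C, C, log=True)
print(G.sum(), log['gw_dist'])  # marginal mass preserved, float output
```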
59 changes: 58 additions & 1 deletion test/test_gromov.py
@@ -122,6 +122,37 @@ def test_asymmetric_gromov(nx):
np.testing.assert_allclose(logb['gw_dist'], 0., atol=1e-04)


+def test_gromov_integer_warnings(nx):
+    n_samples = 10  # nb samples
+    mu_s = np.array([0, 0])
+    cov_s = np.array([[1, 0], [0, 1]])
+
+    xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1)
+    xt = xs[::-1].copy()
+
+    p = ot.unif(n_samples)
+    q = ot.unif(n_samples)
+    G0 = p[:, None] * q[None, :]
+
+    C1 = ot.dist(xs, xs)
+    C2 = ot.dist(xt, xt)
+
+    C1 /= C1.max()
+    C2 /= C2.max()
+    C1 = C1.astype(np.int32)
+    C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
+
+    G = ot.gromov.gromov_wasserstein(
+        C1, C2, None, q, 'square_loss', G0=G0, verbose=True,
+        alpha_min=0., alpha_max=1.)
+    Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(
+        C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=G0b, verbose=True))
+
+    # check constraints
+    np.testing.assert_allclose(G, Gb, atol=1e-06)
+    np.testing.assert_allclose(G, 0., atol=1e-09)

def test_gromov_dtype_device(nx):
# setup
n_samples = 20 # nb samples
@@ -1145,7 +1176,7 @@ def test_fgw(nx):


def test_asymmetric_fgw(nx):
-    n_samples = 50  # nb samples
+    n_samples = 20  # nb samples
rng = np.random.RandomState(0)
C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples))
idx = np.arange(n_samples)
@@ -1221,6 +1252,32 @@ def test_asymmetric_fgw(nx):
np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)


+def test_fgw_integer_warnings(nx):
+    n_samples = 20  # nb samples
+    rng = np.random.RandomState(0)
+    C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples))
+    idx = np.arange(n_samples)
+    rng.shuffle(idx)
+    C2 = C1[idx, :][:, idx]
+
+    # add features
+    F1 = rng.uniform(low=0., high=10, size=(n_samples, 1))
+    F2 = F1[idx, :]
+    p = ot.unif(n_samples)
+    q = ot.unif(n_samples)
+    G0 = p[:, None] * q[None, :]
+
+    M = ot.dist(F1, F2).astype(np.int32)
+    Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
+
+    G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, symmetric=False, verbose=True)
+    Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, symmetric=None, G0=G0b, verbose=True)
+    Gb = nx.to_numpy(Gb)
+    # check constraints
+    np.testing.assert_allclose(G, Gb, atol=1e-06)
+    np.testing.assert_allclose(G, 0., atol=1e-06)


def test_fgw2_gradients():
n_samples = 20 # nb samples

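The tests added above exercise the integer code path but do not assert that the warning itself fires. A hypothetical follow-up test (not part of this PR) could do so with `pytest.warns`, since `warnings.warn` without an explicit category emits a `UserWarning`:

```python
import numpy as np
import pytest
import ot

def test_gw_integer_warning_is_raised():
    rng = np.random.RandomState(0)
    n = 8
    C = rng.randint(0, 5, size=(n, n))
    C = C + C.T  # symmetric integer structure matrix
    p, q = ot.unif(n), ot.unif(n)
    # The match pattern is a substring of the message added in ot/gromov/_gw.py.
    with pytest.warns(UserWarning, match="consists of integer"):
        ot.gromov.gromov_wasserstein(C, C, p, q)
```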