Skip to content

Commit

Permalink
revert the last two commits
Browse files Browse the repository at this point in the history
  • Loading branch information
Francisco Muñoz committed Jan 9, 2025
1 parent e8d4de5 commit b0c5ef3
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 100 deletions.
11 changes: 10 additions & 1 deletion ot/bregman/_convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,11 @@ def _convolutional_barycenter2d_log(
A = list_to_array(A)

nx = get_backend(A)
if nx.__name__ in ("jax", "tf"):
raise NotImplementedError(
"Log-domain functions are not yet implemented"
" for Jax and TF. Use numpy or torch arrays instead."
)

n_hists, width, height = A.shape

Expand Down Expand Up @@ -478,7 +483,11 @@ def _convolutional_barycenter2d_debiased_log(
A = list_to_array(A)
n_hists, width, height = A.shape
nx = get_backend(A)

if nx.__name__ in ("jax", "tf"):
raise NotImplementedError(
"Log-domain functions are not yet implemented"
" for Jax and TF. Use numpy or torch arrays instead."
)
if weights is None:
weights = nx.ones((n_hists,), type_as=A) / n_hists
else:
Expand Down
236 changes: 137 additions & 99 deletions test/test_bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,18 +825,22 @@ def test_wasserstein_bary_2d(nx, method):

# wasserstein
reg = 1e-2
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
A, reg, method=method, verbose=True, log=True
)
bary_wass = nx.to_numpy(
ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
)
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
else:
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
A, reg, method=method, verbose=True, log=True
)
bary_wass = nx.to_numpy(
ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
)

np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)

# help in checking if log and verbose do not bug the function
ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
# help in checking if log and verbose do not bug the function
ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)


@pytest.skip_backend("tf")
Expand All @@ -852,23 +856,27 @@ def test_wasserstein_bary_2d_dtype_device(nx, method):

# wasserstein
reg = 1e-2
# Compute the barycenter with numpy
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
A, reg, method=method, verbose=True, log=True
)
# Compute the barycenter with the backend
bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
# Convert the backend result to numpy, to compare with the numpy result
bary_wass = nx.to_numpy(bary_wass_b)
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
else:
# Compute the barycenter with numpy
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
A, reg, method=method, verbose=True, log=True
)
# Compute the barycenter with the backend
bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
# Convert the backend result to numpy, to compare with the numpy result
bary_wass = nx.to_numpy(bary_wass_b)

np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)

# help in checking if log and verbose do not bug the function
ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
# help in checking if log and verbose do not bug the function
ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)

# Test that the dtype and device are the same after the computation
nx.assert_same_dtype_device(Ab, bary_wass_b)
# Test that the dtype and device are the same after the computation
nx.assert_same_dtype_device(Ab, bary_wass_b)


@pytest.mark.skipif(not tf, reason="tf not installed")
Expand All @@ -886,6 +894,37 @@ def test_wasserstein_bary_2d_device_tf(method):

# wasserstein
reg = 1e-2
if method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
else:
# Compute the barycenter with numpy
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
A, reg, method=method, verbose=True, log=True
)
# Compute the barycenter with the backend
bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
# Convert the backend result to numpy, to compare with the numpy result
bary_wass = nx.to_numpy(bary_wass_b)

np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)

# help in checking if log and verbose do not bug the function
ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)

# Test that the dtype and device are the same after the computation
nx.assert_same_dtype_device(Ab, bary_wass_b)

# Check that everything happens on the GPU
Ab = nx.from_numpy(A)

# wasserstein
reg = 1e-2
if method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
else:
# Compute the barycenter with numpy
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
A, reg, method=method, verbose=True, log=True
Expand All @@ -904,32 +943,9 @@ def test_wasserstein_bary_2d_device_tf(method):
# Test that the dtype and device are the same after the computation
nx.assert_same_dtype_device(Ab, bary_wass_b)

# Check that everything happens on the GPU
Ab = nx.from_numpy(A)

# wasserstein
reg = 1e-2
# Compute the barycenter with numpy
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
A, reg, method=method, verbose=True, log=True
)
# Compute the barycenter with the backend
bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
# Convert the backend result to numpy, to compare with the numpy result
bary_wass = nx.to_numpy(bary_wass_b)

np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)

# help in checking if log and verbose do not bug the function
ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)

# Test that the dtype and device are the same after the computation
nx.assert_same_dtype_device(Ab, bary_wass_b)

# Check this only if GPU is available
if len(tf.config.list_physical_devices("GPU")) > 0:
assert nx.dtype_device(bary_wass_b)[1].startswith("GPU")
# Check this only if GPU is available
if len(tf.config.list_physical_devices("GPU")) > 0:
assert nx.dtype_device(bary_wass_b)[1].startswith("GPU")


@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
Expand All @@ -941,18 +957,22 @@ def test_wasserstein_bary_2d_debiased(nx, method):

# wasserstein
reg = 1e-2
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
A, reg, method=method, verbose=True, log=True
)
bary_wass = nx.to_numpy(
ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
)
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
else:
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
A, reg, method=method, verbose=True, log=True
)
bary_wass = nx.to_numpy(
ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
)

np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)

# help in checking if log and verbose do not bug the function
ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True)
# help in checking if log and verbose do not bug the function
ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True)


@pytest.skip_backend("tf")
Expand All @@ -968,25 +988,31 @@ def test_wasserstein_bary_2d_debiased_dtype_device(nx, method):

# wasserstein
reg = 1e-2
# Compute the barycenter with numpy
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
A, reg, method=method, verbose=True, log=True
)
# Compute the barycenter with the backend
bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(
Ab, reg, method=method
)
# Convert the backend result to numpy, to compare with the numpy result
bary_wass = nx.to_numpy(bary_wass_b)
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)
else:
# Compute the barycenter with numpy
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
A, reg, method=method, verbose=True, log=True
)
# Compute the barycenter with the backend
bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(
Ab, reg, method=method
)
# Convert the backend result to numpy, to compare with the numpy result
bary_wass = nx.to_numpy(bary_wass_b)

np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)

# help in checking if log and verbose do not bug the function
ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True)
# help in checking if log and verbose do not bug the function
ot.bregman.convolutional_barycenter2d_debiased(
A, reg, log=True, verbose=True
)

# Test that the dtype and device are the same after the computation
nx.assert_same_dtype_device(Ab, bary_wass_b)
# Test that the dtype and device are the same after the computation
nx.assert_same_dtype_device(Ab, bary_wass_b)


@pytest.mark.skipif(not tf, reason="tf not installed")
Expand All @@ -1004,6 +1030,41 @@ def test_wasserstein_bary_2d_debiased_device_tf(method):

# wasserstein
reg = 1e-2
if method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)
else:
# Compute the barycenter with numpy
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
A, reg, method=method, verbose=True, log=True
)
# Compute the barycenter with the backend
bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(
Ab, reg, method=method
)
# Convert the backend result to numpy, to compare with the numpy result
bary_wass = nx.to_numpy(bary_wass_b)

np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)

# help in checking if log and verbose do not bug the function
ot.bregman.convolutional_barycenter2d_debiased(
A, reg, log=True, verbose=True
)

# Test that the dtype and device are the same after the computation
nx.assert_same_dtype_device(Ab, bary_wass_b)

# Check that everything happens on the GPU
Ab = nx.from_numpy(A)

# wasserstein
reg = 1e-2
if method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)
else:
# Compute the barycenter with numpy
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
A, reg, method=method, verbose=True, log=True
Expand All @@ -1016,29 +1077,6 @@ def test_wasserstein_bary_2d_debiased_device_tf(method):
bary_wass = nx.to_numpy(bary_wass_b)

np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)

# help in checking if log and verbose do not bug the function
ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True)

# Test that the dtype and device are the same after the computation
nx.assert_same_dtype_device(Ab, bary_wass_b)

# Check that everything happens on the GPU
Ab = nx.from_numpy(A)

# wasserstein
reg = 1e-2
# Compute the barycenter with numpy
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
A, reg, method=method, verbose=True, log=True
)
# Compute the barycenter with the backend
bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)
# Convert the backend result to numpy, to compare with the numpy result
bary_wass = nx.to_numpy(bary_wass_b)

np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)


def test_unmix(nx):
Expand Down

0 comments on commit b0c5ef3

Please sign in to comment.