Commit

Fix simplify ssim implementation (#2563)
* Fixed SSIM issue with variable batch

* Fixed mypy issue

* simplify ssim implementation

* Fix bug

Co-authored-by: vfdev-5 <vfdev.5@gmail.com>
sadra-barikbin and vfdev-5 authored May 4, 2022
1 parent 8f9b080 commit 9810c54
Showing 2 changed files with 44 additions and 24 deletions.
10 changes: 5 additions & 5 deletions ignite/metrics/ssim.py
@@ -98,8 +98,7 @@ def __init__(

     @reinit__is_reduced
     def reset(self) -> None:
-        # Not a tensor because batch size is not known in advance.
-        self._sum_of_batchwise_ssim = 0.0  # type: Union[float, torch.Tensor]
+        self._sum_of_ssim = torch.tensor(0.0, device=self._device)
         self._num_examples = 0
         self._kernel = self._gaussian_or_uniform_kernel(kernel_size=self.kernel_size, sigma=self.sigma)

@@ -176,11 +175,12 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
         b2 = sigma_pred_sq + sigma_target_sq + self.c2

         ssim_idx = (a1 * a2) / (b1 * b2)
-        self._sum_of_batchwise_ssim += torch.mean(ssim_idx, (1, 2, 3), dtype=torch.float64).to(self._device)
+        self._sum_of_ssim += torch.mean(ssim_idx, (1, 2, 3), dtype=torch.float64).sum().to(self._device)
+
         self._num_examples += y.shape[0]

-    @sync_all_reduce("_sum_of_batchwise_ssim", "_num_examples")
+    @sync_all_reduce("_sum_of_ssim", "_num_examples")
     def compute(self) -> torch.Tensor:
         if self._num_examples == 0:
             raise NotComputableError("SSIM must have at least one example before it can be computed.")
-        return torch.sum(self._sum_of_batchwise_ssim / self._num_examples)  # type: ignore[arg-type]
+        return self._sum_of_ssim / self._num_examples
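For context on the change above: before this commit the accumulator started as the float 0.0 and each update added a per-sample SSIM tensor of shape (batch_size,), so its shape was locked in by the first batch and a later batch of a different size could not be added (issue #2532). Reducing each batch to a scalar with .sum() before accumulating, and dividing by the total example count in compute(), removes that constraint. A minimal sketch of the two schemes (illustrative values, not code from this commit):

# Old scheme: accumulate per-sample tensors directly; the shape is fixed by the first batch.
import torch

per_sample_a = torch.rand(12)  # stand-in for torch.mean(ssim_idx, (1, 2, 3)) on a batch of 12
per_sample_b = torch.rand(8)   # the same quantity for a batch of 8

acc = 0.0
acc += per_sample_a            # acc is rebound to a tensor of shape (12,)
try:
    acc += per_sample_b        # in-place add of shape (8,) into shape (12,) raises
except RuntimeError as exc:
    print(f"old accumulation fails: {exc}")

# New scheme: reduce each batch to a scalar before accumulating, then divide by the
# total number of examples, as compute() now does.
total = torch.tensor(0.0, dtype=torch.float64)
num_examples = 0
for batch in (per_sample_a, per_sample_b):
    total += batch.sum()
    num_examples += batch.shape[0]
print("mean SSIM over all examples:", (total / num_examples).item())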
58 changes: 39 additions & 19 deletions tests/ignite/metrics/test_ssim.py
@@ -68,8 +68,17 @@ def test_invalid_ssim():
         ssim.compute()


-def _test_ssim(y_pred, y, data_range, kernel_size, sigma, gaussian, use_sample_covariance, device):
-    atol = 7e-5
+@pytest.mark.parametrize("device", ["cpu"] + ["cuda"] if torch.cuda.is_available() else [])
+@pytest.mark.parametrize(
+    "shape, kernel_size, gaussian, use_sample_covariance",
+    [[(8, 3, 224, 224), 7, False, True], [(12, 3, 28, 28), 11, True, False]],
+)
+def test_ssim(device, shape, kernel_size, gaussian, use_sample_covariance):
+    y_pred = torch.rand(shape, device=device)
+    y = y_pred * 0.8
+
+    sigma = 1.5
+    data_range = 1.0
     ssim = SSIM(data_range=data_range, sigma=sigma, device=device)
     ssim.update((y_pred, y))
     ignite_ssim = ssim.compute()
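The helper-plus-wrapper pattern is replaced above by stacked pytest.mark.parametrize decorators, so pytest generates one test per combination of the device list and the (shape, kernel_size, gaussian, use_sample_covariance) configurations. A hypothetical sketch (names and values are illustrative, not from the diff) of how stacked parametrize decorators multiply out:

import pytest

@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("shape, kernel_size", [[(8, 3, 224, 224), 7], [(12, 3, 28, 28), 11]])
def test_cross_product(device, shape, kernel_size):
    # 2 devices x 2 configurations -> pytest collects 4 test cases
    assert len(shape) == 4 and kernel_size % 2 == 1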
@@ -89,23 +98,34 @@ def _test_ssim(y_pred, y, data_range, kernel_size, sigma, gaussian, use_sample_covariance, device):

     assert isinstance(ignite_ssim, torch.Tensor)
     assert ignite_ssim.dtype == torch.float64
-    assert ignite_ssim.device == torch.device(device)
-    assert np.allclose(ignite_ssim.cpu().numpy(), skimg_ssim, atol=atol)
-
-
-def test_ssim():
-    device = "cuda:0" if torch.cuda.is_available() else "cpu"
-    y_pred = torch.rand(8, 3, 224, 224, device=device)
-    y = y_pred * 0.8
-    _test_ssim(
-        y_pred, y, data_range=1.0, kernel_size=7, sigma=1.5, gaussian=False, use_sample_covariance=True, device=device
-    )
+    assert ignite_ssim.device.type == torch.device(device).type
+    assert np.allclose(ignite_ssim.cpu().numpy(), skimg_ssim, atol=7e-5)
+
+
+def test_ssim_variable_batchsize():
+    # Checks https://github.com/pytorch/ignite/issues/2532
+    sigma = 1.5
+    data_range = 1.0
+    ssim = SSIM(data_range=data_range, sigma=sigma)
+
+    y_preds = [
+        torch.rand(12, 3, 28, 28),
+        torch.rand(12, 3, 28, 28),
+        torch.rand(8, 3, 28, 28),
+        torch.rand(16, 3, 28, 28),
+        torch.rand(1, 3, 28, 28),
+        torch.rand(30, 3, 28, 28),
+    ]
+    y_true = [v * 0.8 for v in y_preds]
+
+    for y_pred, y in zip(y_preds, y_true):
+        ssim.update((y_pred, y))

-    y_pred = torch.rand(12, 3, 28, 28, device=device)
-    y = y_pred * 0.8
-    _test_ssim(
-        y_pred, y, data_range=1.0, kernel_size=11, sigma=1.5, gaussian=True, use_sample_covariance=False, device=device
-    )
+    out = ssim.compute()
+    ssim.reset()
+    ssim.update((torch.cat(y_preds), torch.cat(y_true)))
+    expected = ssim.compute()
+    assert torch.allclose(out, expected)


 def _test_distrib_integration(device, tol=1e-4):
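Two details of the rewritten assertions above are worth noting: the tolerance 7e-5 is now inlined instead of being passed around as atol, and the device check compares device.type rather than full torch.device objects, because a tensor allocated with device="cuda" reports cuda:0 (an explicit index), which does not compare equal to torch.device("cuda"). A minimal sketch of that behaviour (not part of the diff):

import torch

if torch.cuda.is_available():
    t = torch.tensor(0.0, device="cuda")
    print(t.device)                                    # cuda:0
    print(t.device == torch.device("cuda"))            # False: index 0 vs index None
    print(t.device.type == torch.device("cuda").type)  # True: both are "cuda"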
@@ -186,7 +206,7 @@ def _test_distrib_accumulator_device(device):
         y = y_pred * 0.65
         ssim.update((y_pred, y))

-        dev = ssim._sum_of_batchwise_ssim.device
+        dev = ssim._sum_of_ssim.device
         assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"


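Taken together, the changes let the metric accumulate over batches of unequal size, which also covers the common case of a data loader whose last batch is smaller than the rest. A hedged usage sketch (the dataset, step function, and sizes are illustrative, not from the repository):

import torch
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import Engine
from ignite.metrics import SSIM

data = TensorDataset(torch.rand(50, 3, 28, 28))  # 50 images -> batches of 16, 16, 16 and a final 2
loader = DataLoader(data, batch_size=16)

def eval_step(engine, batch):
    (x,) = batch
    y_pred = x * 0.9  # stand-in for a reconstruction / denoising model
    return y_pred, x

evaluator = Engine(eval_step)
SSIM(data_range=1.0).attach(evaluator, "ssim")

state = evaluator.run(loader)
print(state.metrics["ssim"])  # mean SSIM over all 50 examples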
