Fix ssim run on cuda 2532 (#2564)

* Remove unnecessary code in BaseOutputHandler

Closes #2438

* Add ReduceLROnPlateauScheduler

Closes #1754

* Fix indentation issue

* Fix another indentation issue

* Fix PEP8 related issues

* Fix other PEP8 related issues

* Fix hopefully the last PEP8 related issue

* Fix hopefully the last PEP8 related issue

* Remove ReduceLROnPlateau's specific params and add link to it

Also fix bug in min_lr check

* Fix state_dict bug and add a test

* Update docs

* Add doctest and fix typo

* Fix bug when running on CUDA

Co-authored-by: vfdev <vfdev.5@gmail.com>
sadra-barikbin and vfdev-5 authored May 4, 2022
1 parent 9810c54 commit 315b6b9
Showing 2 changed files with 6 additions and 8 deletions.
6 changes: 3 additions & 3 deletions ignite/metrics/ssim.py
@@ -98,7 +98,7 @@ def __init__(
 
     @reinit__is_reduced
     def reset(self) -> None:
-        self._sum_of_ssim = torch.tensor(0.0, device=self._device)
+        self._sum_of_ssim = torch.tensor(0.0, dtype=torch.float64, device=self._device)
         self._num_examples = 0
         self._kernel = self._gaussian_or_uniform_kernel(kernel_size=self.kernel_size, sigma=self.sigma)
 
@@ -180,7 +180,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
         self._num_examples += y.shape[0]
 
     @sync_all_reduce("_sum_of_ssim", "_num_examples")
-    def compute(self) -> torch.Tensor:
+    def compute(self) -> float:
         if self._num_examples == 0:
             raise NotComputableError("SSIM must have at least one example before it can be computed.")
-        return self._sum_of_ssim / self._num_examples
+        return (self._sum_of_ssim / self._num_examples).item()
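For context on what this change buys: the running sum is now accumulated in a float64 tensor and compute() returns a plain Python float, so the metric behaves the same on CPU and CUDA. A minimal usage sketch, assuming ignite's public SSIM API (data_range plus an optional device argument); the tensor shapes and values below are illustrative only:

import torch
from ignite.metrics import SSIM

device = "cuda" if torch.cuda.is_available() else "cpu"

# The running sum is kept as a float64 tensor on the metric's device.
metric = SSIM(data_range=1.0, device=device)

# Illustrative (batch, channels, height, width) inputs in [0, 1].
y_pred = torch.rand(4, 3, 64, 64, device=device)
y = 0.9 * y_pred + 0.1 * torch.rand(4, 3, 64, 64, device=device)

metric.update((y_pred, y))

# compute() now returns a Python float (via .item()),
# so no .cpu()/.numpy() conversion is needed downstream.
score = metric.compute()
print(type(score), score)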
8 changes: 3 additions & 5 deletions tests/ignite/metrics/test_ssim.py
@@ -96,10 +96,8 @@ def test_ssim(device, shape, kernel_size, gaussian, use_sample_covariance):
         use_sample_covariance=use_sample_covariance,
     )
 
-    assert isinstance(ignite_ssim, torch.Tensor)
-    assert ignite_ssim.dtype == torch.float64
-    assert ignite_ssim.device.type == torch.device(device).type
-    assert np.allclose(ignite_ssim.cpu().numpy(), skimg_ssim, atol=7e-5)
+    assert isinstance(ignite_ssim, float)
+    assert np.allclose(ignite_ssim, skimg_ssim, atol=7e-5)
 
 
 def test_ssim_variable_batchsize():
@@ -125,7 +123,7 @@ def test_ssim_variable_batchsize():
     ssim.reset()
     ssim.update((torch.cat(y_preds), torch.cat(y_true)))
     expected = ssim.compute()
-    assert torch.allclose(out, expected)
+    assert np.allclose(out, expected)
 
 
 def _test_distrib_integration(device, tol=1e-4):
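Since compute() now yields two plain floats, the variable-batch-size test switches from torch.allclose to np.allclose. A condensed sketch of the pattern that test exercises, assuming the same SSIM(data_range=...) API as above and illustrative shapes: feeding several batches one by one should match a single concatenated update.

import numpy as np
import torch
from ignite.metrics import SSIM

ssim = SSIM(data_range=1.0)

# Batches of different sizes, as in test_ssim_variable_batchsize.
y_preds = [torch.rand(b, 3, 28, 28) for b in (2, 4, 1)]
y_true = [0.8 * y + 0.2 * torch.rand_like(y) for y in y_preds]

# Accumulate batch by batch.
for yp, y in zip(y_preds, y_true):
    ssim.update((yp, y))
out = ssim.compute()

# Reset and feed everything as a single concatenated batch.
ssim.reset()
ssim.update((torch.cat(y_preds), torch.cat(y_true)))
expected = ssim.compute()

# Both results are Python floats now, so a scalar comparison suffices.
assert np.allclose(out, expected)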
