diff --git a/torchbenchmark/models/pytorch_unet/__init__.py b/torchbenchmark/models/pytorch_unet/__init__.py index dfcdee2d15..033737987a 100644 --- a/torchbenchmark/models/pytorch_unet/__init__.py +++ b/torchbenchmark/models/pytorch_unet/__init__.py @@ -6,7 +6,6 @@ from torch import optim from typing import Tuple -torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False from .pytorch_unet.unet import UNet @@ -89,6 +88,7 @@ def jit_callback(self): self.model = torch.jit.script(self.model) def eval(self) -> Tuple[torch.Tensor]: + torch.backends.cudnn.deterministic = True self.model.eval() with torch.no_grad(): with torch.cuda.amp.autocast(enabled=self.args.amp):