From 0315d0b78a0e92f58f7755b35df338ff4be7947f Mon Sep 17 00:00:00 2001 From: "Liao, Wei" Date: Thu, 11 Apr 2024 16:48:37 +0800 Subject: [PATCH] add condition in enable_amp to fix eval not work under amp --- torchbenchmark/util/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchbenchmark/util/model.py b/torchbenchmark/util/model.py index e21871398f..b1ac813124 100644 --- a/torchbenchmark/util/model.py +++ b/torchbenchmark/util/model.py @@ -483,7 +483,7 @@ def enable_amp(self): self.amp_context = lambda: torch.cpu.amp.autocast() elif self.device == "cuda": self.amp_context = lambda: torch.cuda.amp.autocast() - if is_staged_train_test(self): + if self.test == "train" and is_staged_train_test(self): self.forward_contexts.append(self.amp_context) @property