Skip to content

Commit

Permalink
add condition in enable_amp to fix eval not work under amp
Browse files Browse the repository at this point in the history
  • Loading branch information
wincent8 committed Apr 11, 2024
1 parent 6caee05 commit 0315d0b
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchbenchmark/util/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def enable_amp(self):
self.amp_context = lambda: torch.cpu.amp.autocast()
elif self.device == "cuda":
self.amp_context = lambda: torch.cuda.amp.autocast()
if is_staged_train_test(self):
if self.test == "train" and is_staged_train_test(self):
self.forward_contexts.append(self.amp_context)

@property
Expand Down

0 comments on commit 0315d0b

Please sign in to comment.