From 92071f1d15e6f5d3892cff57ea83ac9e7a0a7eb8 Mon Sep 17 00:00:00 2001
From: "Liao, Wei"
Date: Tue, 16 Apr 2024 11:13:17 +0800
Subject: [PATCH] add warning for non-staged training model

Replace the hard assert with a warning when an AMP train test is
requested on a model without staged train support, and make
enable_amp handle the eval, staged-train, and non-staged-train cases
explicitly instead of silently skipping non-staged models.

---
 torchbenchmark/util/extra_args.py |  7 ++++---
 torchbenchmark/util/model.py      | 14 ++++++++++++--
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/torchbenchmark/util/extra_args.py b/torchbenchmark/util/extra_args.py
index 78bc5235d0..c9f2b693ef 100644
--- a/torchbenchmark/util/extra_args.py
+++ b/torchbenchmark/util/extra_args.py
@@ -188,9 +188,10 @@ def apply_decoration_args(
             model.add_context(lambda: torch.cuda.amp.autocast(dtype=torch.float16))
         elif model.test == "train":
             # the model must implement staged train test
-            assert is_staged_train_test(
-                model
-            ), f"Expected model implements staged train test (forward, backward, optimizer)."
+            warnings.warn(
+                "Models usually should not enable AMP only in the forward path; the model "
+                "is expected to implement the staged train test (forward, backward, optimizer)."
+            )
             import torch

             model.add_context(
diff --git a/torchbenchmark/util/model.py b/torchbenchmark/util/model.py
index b1ac813124..3a1d76a84e 100644
--- a/torchbenchmark/util/model.py
+++ b/torchbenchmark/util/model.py
@@ -483,8 +483,18 @@ def enable_amp(self):
             self.amp_context = lambda: torch.cpu.amp.autocast()
         elif self.device == "cuda":
             self.amp_context = lambda: torch.cuda.amp.autocast()
-        if self.test == "train" and is_staged_train_test(self):
-            self.forward_contexts.append(self.amp_context)
+        if self.test == "eval":
+            self.add_context(self.amp_context)
+        elif self.test == "train":
+            if is_staged_train_test(self):
+                self.add_context(self.amp_context, TEST_STAGE.FORWARD)
+            else:
+                warnings.warn(
+                    "Models usually should not enable AMP only in the forward path; the model "
+                    "is expected to implement the staged train test (forward, backward, optimizer)."
+                )
+
+

     @property
     def pt2_compilation_time(self) -> Optional[float]:
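
---
For context (not applied by git am): a minimal sketch of the staged train
contract this patch keys on. `is_staged_train_test` and `forward_contexts`
are the real torchbenchmark identifiers referenced above; everything else
(StagedToy, the Linear net, the `optimizer_step` method name beyond the
forward/backward/optimizer stages named in the removed assert) is
illustrative, not the library's actual implementation.

    import contextlib
    import torch

    def is_staged_train_test(model) -> bool:
        # Staged train support means the train test is split into separate
        # callables, so a context such as AMP autocast can be attached to
        # the forward stage alone (the TEST_STAGE.FORWARD path above).
        return all(
            callable(getattr(model, stage, None))
            for stage in ("forward", "backward", "optimizer_step")
        )

    class StagedToy:
        """Hypothetical model with staged train support."""

        def __init__(self):
            self.net = torch.nn.Linear(8, 8)
            self.opt = torch.optim.SGD(self.net.parameters(), lr=0.01)
            self.inp = torch.randn(4, 8)
            # add_context(ctx, TEST_STAGE.FORWARD) effectively lands here.
            self.forward_contexts = []

        def forward(self):
            # Only this stage runs under the forward contexts, which is why
            # AMP can be scoped to the forward pass for staged models, and
            # why the patch warns instead of asserting for non-staged ones.
            with contextlib.ExitStack() as stack:
                for ctx in self.forward_contexts:
                    stack.enter_context(ctx())
                self.loss = self.net(self.inp).sum()

        def backward(self):
            self.loss.backward()

        def optimizer_step(self):
            self.opt.step()
            self.opt.zero_grad()

    toy = StagedToy()
    toy.forward_contexts.append(lambda: torch.cpu.amp.autocast())
    assert is_staged_train_test(toy)
    toy.forward(); toy.backward(); toy.optimizer_step()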