Skip to content

Commit

Permalink
add warning for not staged training model
Browse files Browse the repository at this point in the history
  • Loading branch information
wincent8 committed Apr 16, 2024
1 parent 0315d0b commit 92071f1
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
7 changes: 4 additions & 3 deletions torchbenchmark/util/extra_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,10 @@ def apply_decoration_args(
model.add_context(lambda: torch.cuda.amp.autocast(dtype=torch.float16))
elif model.test == "train":
# the model must implement staged train test
assert is_staged_train_test(
model
), f"Expected model implements staged train test (forward, backward, optimizer)."
warnings.warn(
"Usually models do not want to enable amp only in forward path, so expected "
"model to have staged train support."
)
import torch

model.add_context(
Expand Down
14 changes: 12 additions & 2 deletions torchbenchmark/util/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,18 @@ def enable_amp(self):
self.amp_context = lambda: torch.cpu.amp.autocast()
elif self.device == "cuda":
self.amp_context = lambda: torch.cuda.amp.autocast()
if self.test == "train" and is_staged_train_test(self):
self.forward_contexts.append(self.amp_context)
if self.test == "eval":
self.add_context(self.amp_context)
elif self.test == "train":
if is_staged_train_test(self):
self.add_context(self.amp_context, TEST_STAGE.FORWARD)
else:
warnings.warn(
"Usually models do not want to enable amp only in forward path, so expected "
"model to have staged train support."
)



@property
def pt2_compilation_time(self) -> Optional[float]:
Expand Down

0 comments on commit 92071f1

Please sign in to comment.