From 25338db930e84d977d5fb58c198462c1f57581bb Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Wed, 21 Feb 2024 11:53:46 -0800 Subject: [PATCH] Prioritize user-defined `train()` function over the staged `forward()` (#2174) Summary: It gives more user-friendly error messages upon unimplemented train tests. Fixes https://github.com/pytorch/benchmark/issues/2166 Pull Request resolved: https://github.com/pytorch/benchmark/pull/2174 Test Plan: ``` $ python -u run.py -d cuda -t train --bs 4 --metrics None hf_Whisper /home/runner/miniconda3/envs/torchbench/lib/python3.11/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead. _torch_pytree._register_pytree_node( /home/runner/miniconda3/envs/torchbench/lib/python3.11/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead. _torch_pytree._register_pytree_node( /home/runner/miniconda3/envs/torchbench/lib/python3.11/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead. _torch_pytree._register_pytree_node( Running train method from hf_Whisper on cuda in eager mode with input batch size 4 and precision fp32. Traceback (most recent call last): File "/workspace/benchmark/run.py", line 623, in main() # pragma: no cover ^^^^^^ File "/workspace/benchmark/run.py", line 593, in main run_one_step( File "/workspace/benchmark/run.py", line 173, in run_one_step func() File "/workspace/benchmark/torchbenchmark/util/model.py", line 315, in invoke self.train() File "/workspace/benchmark/torchbenchmark/models/hf_Whisper/__init__.py", line 20, in train raise NotImplementedError("Training is not implemented.") NotImplementedError: Training is not implemented. ``` Reviewed By: aaronenyeshi Differential Revision: D54012510 Pulled By: xuzhao9 fbshipit-source-id: bb27bd5adb0bcd778c2c58db7ef5a7b8cc9b2c20 --- torchbenchmark/util/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchbenchmark/util/model.py b/torchbenchmark/util/model.py index 6846edf3ff..1c0f7258a0 100644 --- a/torchbenchmark/util/model.py +++ b/torchbenchmark/util/model.py @@ -306,7 +306,7 @@ def _invoke_staged_train_test(self, num_batch: int) -> None: return None def invoke(self) -> Optional[Tuple[torch.Tensor]]: - if self.test == "train" and is_staged_train_test(self): + if self.test == "train" and is_staged_train_test(self) and (getattr(self, "train", None) == None): return self._invoke_staged_train_test(num_batch=self.num_batch) assert self.num_batch == 1, "Only staged_train_test supports multiple-batch testing at this time." out = None