Skip to content

Commit

Permalink
Prioritize user-defined train() function over the staged forward() (
Browse files Browse the repository at this point in the history
#2174)

Summary:
It gives more user-friendly error messages upon unimplemented train tests.

Fixes #2166

Pull Request resolved: #2174

Test Plan:
```
$ python -u run.py -d cuda -t train --bs 4 --metrics None hf_Whisper
/home/runner/miniconda3/envs/torchbench/lib/python3.11/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/runner/miniconda3/envs/torchbench/lib/python3.11/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/runner/miniconda3/envs/torchbench/lib/python3.11/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
Running train method from hf_Whisper on cuda in eager mode with input batch size 4 and precision fp32.
Traceback (most recent call last):
  File "/workspace/benchmark/run.py", line 623, in <module>
    main()  # pragma: no cover
    ^^^^^^
  File "/workspace/benchmark/run.py", line 593, in main
    run_one_step(
  File "/workspace/benchmark/run.py", line 173, in run_one_step
    func()
  File "/workspace/benchmark/torchbenchmark/util/model.py", line 315, in invoke
    self.train()
  File "/workspace/benchmark/torchbenchmark/models/hf_Whisper/__init__.py", line 20, in train
    raise NotImplementedError("Training is not implemented.")
NotImplementedError: Training is not implemented.
```

Reviewed By: aaronenyeshi

Differential Revision: D54012510

Pulled By: xuzhao9

fbshipit-source-id: bb27bd5adb0bcd778c2c58db7ef5a7b8cc9b2c20
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Feb 21, 2024
1 parent 29a3fba commit 25338db
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchbenchmark/util/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def _invoke_staged_train_test(self, num_batch: int) -> None:
return None

def invoke(self) -> Optional[Tuple[torch.Tensor]]:
if self.test == "train" and is_staged_train_test(self):
if self.test == "train" and is_staged_train_test(self) and (getattr(self, "train", None) == None):
return self._invoke_staged_train_test(num_batch=self.num_batch)
assert self.num_batch == 1, "Only staged_train_test supports multiple-batch testing at this time."
out = None
Expand Down

0 comments on commit 25338db

Please sign in to comment.