diff --git a/torchbenchmark/util/model.py b/torchbenchmark/util/model.py index 6846edf3ff..1c0f7258a0 100644 --- a/torchbenchmark/util/model.py +++ b/torchbenchmark/util/model.py @@ -306,7 +306,7 @@ def _invoke_staged_train_test(self, num_batch: int) -> None: return None def invoke(self) -> Optional[Tuple[torch.Tensor]]: - if self.test == "train" and is_staged_train_test(self): + if self.test == "train" and is_staged_train_test(self) and (getattr(self, "train", None) == None): return self._invoke_staged_train_test(num_batch=self.num_batch) assert self.num_batch == 1, "Only staged_train_test supports multiple-batch testing at this time." out = None