From 4ec3cfb5b18d755b8f49310dc531d7c9f2167857 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 20 Feb 2024 12:33:53 -0500 Subject: [PATCH 1/2] Test the train function --- torchbenchmark/util/framework/huggingface/model_factory.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchbenchmark/util/framework/huggingface/model_factory.py b/torchbenchmark/util/framework/huggingface/model_factory.py index b852717f01..dbb7ff4677 100644 --- a/torchbenchmark/util/framework/huggingface/model_factory.py +++ b/torchbenchmark/util/framework/huggingface/model_factory.py @@ -125,6 +125,10 @@ def __init__(self, name, test, device, batch_size=None, extra_args=[]): self.example_inputs['decoder_input_ids'] = eval_context self.model.eval() self.amp_context = nullcontext + if test == "train" and hasattr(self, "train"): + # users might implement a placeholder train function that throws NotImplementedError + # in this case, try invoking this function to skip the run. + self.train() def get_module(self, wrap_model=True): if not self.is_generate and class_models[self.unqual_name][3] == 'AutoModelForSeq2SeqLM': From f7bba0a2fa919090c48d501315dec61a5a40c9ee Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 20 Feb 2024 12:36:50 -0500 Subject: [PATCH 2/2] Prioritize the `train()` method over `forward()`. --- torchbenchmark/util/framework/huggingface/model_factory.py | 4 ---- torchbenchmark/util/model.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/torchbenchmark/util/framework/huggingface/model_factory.py b/torchbenchmark/util/framework/huggingface/model_factory.py index dbb7ff4677..b852717f01 100644 --- a/torchbenchmark/util/framework/huggingface/model_factory.py +++ b/torchbenchmark/util/framework/huggingface/model_factory.py @@ -125,10 +125,6 @@ def __init__(self, name, test, device, batch_size=None, extra_args=[]): self.example_inputs['decoder_input_ids'] = eval_context self.model.eval() self.amp_context = nullcontext - if test == "train" and hasattr(self, "train"): - # users might implement a placeholder train function that throws NotImplementedError - # in this case, try invoking this function to skip the run. - self.train() def get_module(self, wrap_model=True): if not self.is_generate and class_models[self.unqual_name][3] == 'AutoModelForSeq2SeqLM': diff --git a/torchbenchmark/util/model.py b/torchbenchmark/util/model.py index 6846edf3ff..1c0f7258a0 100644 --- a/torchbenchmark/util/model.py +++ b/torchbenchmark/util/model.py @@ -306,7 +306,7 @@ def _invoke_staged_train_test(self, num_batch: int) -> None: return None def invoke(self) -> Optional[Tuple[torch.Tensor]]: - if self.test == "train" and is_staged_train_test(self): + if self.test == "train" and is_staged_train_test(self) and (getattr(self, "train", None) == None): return self._invoke_staged_train_test(num_batch=self.num_batch) assert self.num_batch == 1, "Only staged_train_test supports multiple-batch testing at this time." out = None