diff --git a/torchbenchmark/util/model.py b/torchbenchmark/util/model.py index 5bfff714a6..a4a6f7137c 100644 --- a/torchbenchmark/util/model.py +++ b/torchbenchmark/util/model.py @@ -188,12 +188,10 @@ def _get_batch_size_from_metadata(self) -> Optional[str]: return batch_size def _determine_batch_size(self, user_specified_batch_size=None): - # batch size priority for eval tests: not ALLOW_CUSTOMIZE_BSIZE > user specified > device specified > default - # batch size priority for train tests: not ALLOW_CUSTOMIZE_BSIZE > user specified > default self.batch_size = user_specified_batch_size - if not self.batch_size: + if not self.batch_size and getattr(self, "ALLOW_CUSTOMIZE_BSIZE", True): device_specified_batch_size = self._get_batch_size_from_metadata() self.batch_size = device_specified_batch_size @@ -205,7 +203,7 @@ def _determine_batch_size(self, user_specified_batch_size=None): raise NotImplementedError(f"Model's {'DEFAULT_TRAIN_BSIZE' if self.test == 'train' else 'DEFAULT_EVAL_BSIZE'} is not implemented.") # Check if specified batch size is supported by the model - if hasattr(self, "ALLOW_CUSTOMIZE_BSIZE") and (not getattr(self, "ALLOW_CUSTOMIZE_BSIZE")): + if not getattr(self, "ALLOW_CUSTOMIZE_BSIZE", True): if self.test == "train" and (not self.batch_size == self.DEFAULT_TRAIN_BSIZE): raise NotImplementedError(f"Model doesn't support customizing batch size, but {self.test} test is providing a batch size other than DEFAULT_TRAIN_BSIZE") elif self.test == "eval" and (not self.batch_size == self.DEFAULT_EVAL_BSIZE):