Skip to content

Commit

Permalink
Default to supported model batch size instead of metadata if batch si…
Browse files Browse the repository at this point in the history
…ze not specified

ghstack-source-id: 9650512d1875e64cdbbe53a8259f3d9b9cbe2674
Pull Request resolved: #2015
  • Loading branch information
eellison committed Oct 27, 2023
1 parent edf7115 commit 0d81d78
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions torchbenchmark/util/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,9 @@ def _get_batch_size_from_metadata(self) -> Optional[str]:
return batch_size

def _determine_batch_size(self, user_specified_batch_size=None):
# batch size priority for eval tests: not ALLOW_CUSTOMIZE_BSIZE > user specified > device specified > default
# batch size priority for train tests: not ALLOW_CUSTOMIZE_BSIZE > user specified > default

self.batch_size = user_specified_batch_size

if not self.batch_size:
if not self.batch_size and getattr(self, "ALLOW_CUSTOMIZE_BSIZE", False):
device_specified_batch_size = self._get_batch_size_from_metadata()
self.batch_size = device_specified_batch_size

Expand All @@ -205,7 +202,7 @@ def _determine_batch_size(self, user_specified_batch_size=None):
raise NotImplementedError(f"Model's {'DEFAULT_TRAIN_BSIZE' if self.test == 'train' else 'DEFAULT_EVAL_BSIZE'} is not implemented.")

# Check if specified batch size is supported by the model
if hasattr(self, "ALLOW_CUSTOMIZE_BSIZE") and (not getattr(self, "ALLOW_CUSTOMIZE_BSIZE")):
if not getattr(self, "ALLOW_CUSTOMIZE_BSIZE", True):
if self.test == "train" and (not self.batch_size == self.DEFAULT_TRAIN_BSIZE):
raise NotImplementedError(f"Model doesn't support customizing batch size, but {self.test} test is providing a batch size other than DEFAULT_TRAIN_BSIZE")
elif self.test == "eval" and (not self.batch_size == self.DEFAULT_EVAL_BSIZE):
Expand Down

0 comments on commit 0d81d78

Please sign in to comment.