Skip to content

Commit

Permalink
Dont try to use default device batch size if we model doesnt support …
Browse files Browse the repository at this point in the history
…customization

[ghstack-poisoned]
  • Loading branch information
eellison committed Oct 26, 2023
1 parent d00b942 commit e940fcc
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions torchbenchmark/util/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,17 @@ def _determine_batch_size(self, batch_size=None):
device_batch_size_key = f"{self.test}_batch_size"
if self.metadata and "devices" in self.metadata and current_device_name in self.metadata["devices"] \
and device_batch_size_key in self.metadata["devices"][current_device_name]:
self.batch_size = self.metadata["devices"][current_device_name][device_batch_size_key]
# only use default device suggestion if we can customize
if getattr(self, "ALLOW_CUSTOMIZE_BSIZE", False):
self.batch_size = self.metadata["devices"][current_device_name][device_batch_size_key]
# If the model doesn't implement test or eval test
# its DEFAULT_TRAIN_BSIZE or DEFAULT_EVAL_BSIZE will still be None
if not self.batch_size:
raise NotImplementedError(f"Test {self.test} is not implemented.")
else:
self.batch_size = batch_size
# Check if specified batch size is supported by the model
if hasattr(self, "ALLOW_CUSTOMIZE_BSIZE") and (not getattr(self, "ALLOW_CUSTOMIZE_BSIZE")):
if not getattr(self, "ALLOW_CUSTOMIZE_BSIZE", True):
if self.test == "train" and (not self.batch_size == self.DEFAULT_TRAIN_BSIZE):
raise NotImplementedError("Model doesn't support customizing batch size.")
elif self.test == "eval" and (not self.batch_size == self.DEFAULT_EVAL_BSIZE):
Expand Down

0 comments on commit e940fcc

Please sign in to comment.