Skip to content

Commit

Permalink
Align xpu models batch size with A100 (#2378)
Browse files Browse the repository at this point in the history
Summary:
To align xpu batchsize for dynamobenchmark torchbench suite

Pull Request resolved: #2378

Reviewed By: aaronenyeshi

Differential Revision: D59961717

Pulled By: xuzhao9

fbshipit-source-id: c926d6d14d8b979284aa465738132c17425dafe9
  • Loading branch information
chuanqi129 authored and facebook-github-bot committed Jul 22, 2024
1 parent 11cf319 commit 03cde49
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions torchbenchmark/util/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
)
from torchbenchmark.util.input import input_cast, ModelInputDescriptor

SPECIAL_DEVICE_MAPPING = {"AMD Instinct MI210": "NVIDIA A100-SXM4-40GB"}
SPECIAL_DEVICE_MAPPING = {"AMD Instinct MI210": "NVIDIA A100-SXM4-40GB", "Intel(R) Data Center GPU Max 1100": "NVIDIA A100-SXM4-40GB", "Intel(R) Data Center GPU Max 1550": "NVIDIA A100-SXM4-40GB"}


class PostInitProcessor(type):
Expand Down Expand Up @@ -211,16 +211,24 @@ def _determine_dynamic_num_batches(
return 1

def _get_batch_size_from_metadata(self) -> Optional[str]:
if self.device != "cuda":
current_device_name = str(self.device)
else:
if self.device == "cuda":
current_device_name = (
torch.cuda.get_device_name()
if torch.cuda.get_device_name()
else "UNKNOWN"
)
if current_device_name in SPECIAL_DEVICE_MAPPING:
current_device_name = SPECIAL_DEVICE_MAPPING[current_device_name]
elif self.device == "xpu":
current_device_name = (
torch.xpu.get_device_name()
if torch.xpu.get_device_name()
else "UNKNOWN"
)
if current_device_name in SPECIAL_DEVICE_MAPPING:
current_device_name = SPECIAL_DEVICE_MAPPING[current_device_name]
else:
current_device_name = str(self.device)

# use the device suggestion on CUDA inference tests, key should be either eval_batch_size or train_batch_size
device_batch_size_key = f"{self.test}_batch_size"
Expand Down

0 comments on commit 03cde49

Please sign in to comment.