From 3f1c3ebe751ae2f8f5777565b28a8968d5945ddc Mon Sep 17 00:00:00 2001
From: Simon Fan
Date: Tue, 26 Sep 2023 16:33:36 -0700
Subject: [PATCH] Fix PyTorch CI HUD dashboard missing perf numbers: hf_Whisper (#1935)

Summary:
A few models were passing the accuracy check but surprisingly failing the perf run, resulting in dashboard entries like:

[image: dashboard screenshot]

Reproducing the HUD's commands locally:

```
# pass
python benchmarks/dynamo/torchbench.py --accuracy --no-translation-validation --training --amp --backend inductor --disable-cudagraphs --device cuda --total-partitions 4 --partition-id 1 --output hf_Whisper_accuracy.csv --only hf_Whisper

# fail (on https://github.com/pytorch/benchmark/blob/4ea3bba3b8010f5d4a629bb8f530a92570f34518/torchbenchmark/util/model.py#L195C48-L195C48)
python benchmarks/dynamo/torchbench.py --performance --cold-start-latency --training --amp --backend inductor --disable-cudagraphs --device cuda --total-partitions 4 --partition-id 1 --output hf_Whisper_perf.csv --only hf_Whisper
```

The error suggests that hf_Whisper does not provide a batch size for the training-mode perf run.

Summarizing the discussion with xuzhao9:
> I think we could:
> 1. set a default train batch size for hf_Whisper, if you still want to test forward/backward pass without a defined train test
> 2. in model.py, make sure self.batch_size is not None (before accuracy check overrides batch size to 4)

I implemented option 1: we set default batch sizes in the parent class of all benchmark models, with the ability to be overridden by individual models.
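As an illustration only (not part of this patch; the class names below are hypothetical), a minimal sketch of the intended pattern: defaults declared on the shared base class, overridden by individual models such as hf_Whisper:

```python
from typing import Optional

class BenchmarkModel:
    # Shared defaults on the base class; None means "no default batch size for this test".
    DEFAULT_TRAIN_BSIZE: Optional[int] = None
    DEFAULT_EVAL_BSIZE: Optional[int] = None

class WhisperLike(BenchmarkModel):
    # An individual model overrides the defaults it supports, mirroring the
    # DEFAULT_TRAIN_BSIZE = 8 added to hf_Whisper in this patch.
    DEFAULT_TRAIN_BSIZE = 8
    DEFAULT_EVAL_BSIZE = 8
```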
Pull Request resolved: https://github.com/pytorch/benchmark/pull/1935

Reviewed By: xuzhao9

Differential Revision: D49641235

Pulled By: xmfan

fbshipit-source-id: 2f93fb742846d7c34936cbbc8e8d3e22c5a76662
---
 torchbenchmark/models/hf_Whisper/__init__.py |  1 +
 .../models/hf_Whisper/metadata.yaml          |  1 +
 torchbenchmark/util/model.py                 | 57 +++++++++++--------
 3 files changed, 35 insertions(+), 24 deletions(-)

diff --git a/torchbenchmark/models/hf_Whisper/__init__.py b/torchbenchmark/models/hf_Whisper/__init__.py
index a03b45a55c..b77d9ae5d6 100644
--- a/torchbenchmark/models/hf_Whisper/__init__.py
+++ b/torchbenchmark/models/hf_Whisper/__init__.py
@@ -4,6 +4,7 @@
 
 class Model(HuggingFaceModel):
     task = SPEECH.RECOGNITION
+    DEFAULT_TRAIN_BSIZE = 8
     DEFAULT_EVAL_BSIZE = 8
     DEFAULT_EVAL_CUDA_PRECISION = "fp16"
 
diff --git a/torchbenchmark/models/hf_Whisper/metadata.yaml b/torchbenchmark/models/hf_Whisper/metadata.yaml
index 1cd8c54a52..743608a788 100644
--- a/torchbenchmark/models/hf_Whisper/metadata.yaml
+++ b/torchbenchmark/models/hf_Whisper/metadata.yaml
@@ -6,5 +6,6 @@ eval_deterministic: false
 eval_nograd: true
 not_implemented:
 - device: cpu
+- test: train
 train_benchmark: false
 train_deterministic: false
\ No newline at end of file
diff --git a/torchbenchmark/util/model.py b/torchbenchmark/util/model.py
index c85f466539..a842e15de2 100644
--- a/torchbenchmark/util/model.py
+++ b/torchbenchmark/util/model.py
@@ -171,36 +171,45 @@ def _determine_dynamic_num_batches(self, user_specified_num_batches: Optional[in
         assert hasattr(self, 'DEFAULT_NUM_BATCH'), f"We expect all models with dynamic shapes specify field `DEFAULT_NUM_BATCHES`."
         return self.DEFAULT_NUM_BATCH
 
-    def _determine_batch_size(self, batch_size=None):
+    def _get_batch_size_from_metadata(self) -> Optional[str]:
+        if self.device != "cuda":
+            current_device_name = str(self.device)
+        else:
+            current_device_name = torch.cuda.get_device_name()
+            assert current_device_name, f"torch.cuda.get_device_name() returns None when device is set to cuda, please double check."
+        if current_device_name in SPECIAL_DEVICE_MAPPING:
+            current_device_name = SPECIAL_DEVICE_MAPPING[current_device_name]
+
+        # use the device suggestion on CUDA inference tests, key should be either eval_batch_size or train_batch_size
+        device_batch_size_key = f"{self.test}_batch_size"
+        if self.metadata and "devices" in self.metadata and current_device_name in self.metadata["devices"] \
+            and device_batch_size_key in self.metadata["devices"][current_device_name]:
+            batch_size = self.metadata["devices"][current_device_name][device_batch_size_key]
+            return batch_size
+
+    def _determine_batch_size(self, user_specified_batch_size=None):
         # batch size priority for eval tests: not ALLOW_CUSTOMIZE_BSIZE > user specified > device specified > default
         # batch size priority for train tests: not ALLOW_CUSTOMIZE_BSIZE > user specified > default
-        self.batch_size = batch_size
-        if not batch_size:
-            self.batch_size = self.DEFAULT_TRAIN_BSIZE if self.test == "train" else self.DEFAULT_EVAL_BSIZE
-            if self.device == "cuda":
-                current_device_name = torch.cuda.get_device_name()
-                assert current_device_name, f"torch.cuda.get_device_name() returns None when device is set to cuda, please double check."
-                if current_device_name in SPECIAL_DEVICE_MAPPING:
-                    current_device_name = SPECIAL_DEVICE_MAPPING[current_device_name]
-            else:
-                current_device_name = str(self.device)
-            # use the device suggestion on CUDA inference tests, key should be either eval_batch_size or train_batch_size
-            device_batch_size_key = f"{self.test}_batch_size"
-            if self.metadata and "devices" in self.metadata and current_device_name in self.metadata["devices"] \
-                and device_batch_size_key in self.metadata["devices"][current_device_name]:
-                self.batch_size = self.metadata["devices"][current_device_name][device_batch_size_key]
-            # If the model doesn't implement test or eval test
-            # its DEFAULT_TRAIN_BSIZE or DEFAULT_EVAL_BSIZE will still be None
-            if not self.batch_size:
-                raise NotImplementedError(f"Test {self.test} is not implemented.")
-        else:
-            self.batch_size = batch_size
+
+        self.batch_size = user_specified_batch_size
+
+        if not self.batch_size:
+            device_specified_batch_size = self._get_batch_size_from_metadata()
+            self.batch_size = device_specified_batch_size
+
+        if not self.batch_size:
+            default_batch_size = self.DEFAULT_TRAIN_BSIZE if self.test == "train" else self.DEFAULT_EVAL_BSIZE
+            self.batch_size = default_batch_size
+
+        if not self.batch_size:
+            raise NotImplementedError(f"Model's {'DEFAULT_TRAIN_BSIZE' if self.test == 'train' else 'DEFAULT_EVAL_BSIZE'} is not implemented.")
+
         # Check if specified batch size is supported by the model
         if hasattr(self, "ALLOW_CUSTOMIZE_BSIZE") and (not getattr(self, "ALLOW_CUSTOMIZE_BSIZE")):
             if self.test == "train" and (not self.batch_size == self.DEFAULT_TRAIN_BSIZE):
-                raise NotImplementedError("Model doesn't support customizing batch size.")
+                raise NotImplementedError(f"Model doesn't support customizing batch size, but {self.test} test is providing a batch size other than DEFAULT_TRAIN_BSIZE")
             elif self.test == "eval" and (not self.batch_size == self.DEFAULT_EVAL_BSIZE):
-                raise NotImplementedError("Model doesn't support customizing batch size.")
+                raise NotImplementedError(f"Model doesn't support customizing batch size, but {self.test} test is providing a batch size other than DEFAULT_EVAL_BSIZE")
         elif self.dargs.accuracy:
             self.batch_size = 4 if self.batch_size > 4 else self.batch_size
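For reference, a minimal, standalone sketch (not part of the patch; the helper name `resolve_batch_size` is illustrative) of the resolution order the refactored `_determine_batch_size` follows: user specified > device specified (from metadata) > model default, with the ALLOW_CUSTOMIZE_BSIZE guard applied afterwards:

```python
from typing import Optional

def resolve_batch_size(
    user_specified: Optional[int],
    device_specified: Optional[int],
    default: Optional[int],
    allow_customize: bool = True,
) -> int:
    # Priority: user specified > device specified (metadata) > class default.
    batch_size = user_specified or device_specified or default
    if not batch_size:
        # Mirrors the new error path: the model defines no default for this test.
        raise NotImplementedError("No default batch size is defined for this test.")
    if not allow_customize and batch_size != default:
        # Mirrors the ALLOW_CUSTOMIZE_BSIZE check.
        raise NotImplementedError("Model doesn't support customizing batch size.")
    return batch_size

# hf_Whisper's training perf run previously fell through to the error branch;
# with DEFAULT_TRAIN_BSIZE = 8 it now resolves to 8.
assert resolve_batch_size(None, None, 8) == 8
```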