diff --git a/userbenchmark/dynamo/dynamobench/common.py b/userbenchmark/dynamo/dynamobench/common.py
index 2dd5736f55..d610d7dd13 100644
--- a/userbenchmark/dynamo/dynamobench/common.py
+++ b/userbenchmark/dynamo/dynamobench/common.py
@@ -2578,6 +2578,14 @@ def record_status(accuracy_status, dynamo_start_stats):
 
             # E.g., the output order might not match, None might be part of output, etc.
             try:
+                if self.args.training and self.args.amp:
+                    if process_fn := self.get_output_amp_train_process_func.get(
+                        name, None
+                    ):
+                        correct_result = process_fn(correct_result)
+                        new_result = process_fn(new_result)
+                        fp64_outputs = process_fn(fp64_outputs)
+
                 if not same(
                     correct_result,
                     new_result,
diff --git a/userbenchmark/dynamo/dynamobench/torchbench.py b/userbenchmark/dynamo/dynamobench/torchbench.py
index 2b434b66e5..a6b4edb3a4 100755
--- a/userbenchmark/dynamo/dynamobench/torchbench.py
+++ b/userbenchmark/dynamo/dynamobench/torchbench.py
@@ -88,6 +88,30 @@ def maybe_list_to_set(obj):
     return maybe_list_to_set(data)
 
 
+def process_hf_reformer_output(out):
+    assert isinstance(out, list)
+    # second output is unstable
+    return [elem for i, elem in enumerate(out) if i != 1]
+
+
+def process_hf_whisper_output(out):
+    out_ret = []
+    for i, elem in enumerate(out):
+        if i == 0:
+            assert isinstance(elem, dict)
+            out_ret.append({k: v for k, v in elem.items() if k != "logits"})
+        elif i != 1:
+            out_ret.append(elem)
+
+    return out_ret
+
+
+process_train_model_output = {
+    "hf_Reformer": process_hf_reformer_output,
+    "hf_Whisper": process_hf_whisper_output,
+}
+
+
 class TorchBenchmarkRunner(BenchmarkRunner):
     def __init__(self):
         super().__init__()
@@ -142,6 +166,10 @@ def very_slow_models(self):
     def non_deterministic_models(self):
         return self._config["non_deterministic"]
 
+    @property
+    def get_output_amp_train_process_func(self):
+        return process_train_model_output
+
     @property
     def skip_not_suitable_for_training_models(self):
         return self._skip["test"]["training"]