From 8339eef6ba040c90c12cf6199e43bbb56c624c10 Mon Sep 17 00:00:00 2001
From: BowenBao
Date: Fri, 20 Oct 2023 00:37:25 -0700
Subject: [PATCH] Benchmark to store test data along exported model (#111095)

Summary:
X-link: https://github.com/pytorch/pytorch/pull/111095

Approved by: https://github.com/justinchuby, https://github.com/thiagocrepaldi

Reviewed By: jeanschmidt

Differential Revision: D50475063

fbshipit-source-id: 13a53e32d32bf66801c68d38d4f1f97f3b1daaab
---
 userbenchmark/dynamo/dynamobench/common.py | 303 +++++++++++++--------
 1 file changed, 190 insertions(+), 113 deletions(-)

diff --git a/userbenchmark/dynamo/dynamobench/common.py b/userbenchmark/dynamo/dynamobench/common.py
index 81592709ac..a5654df0b2 100644
--- a/userbenchmark/dynamo/dynamobench/common.py
+++ b/userbenchmark/dynamo/dynamobench/common.py
@@ -1,6 +1,8 @@
 #!/usr/bin/env python3
 from __future__ import annotations
 
+import abc
+
 import argparse
 import collections
 import contextlib
@@ -29,6 +31,7 @@
     Mapping,
     NamedTuple,
     Optional,
+    Sequence,
     Tuple,
     Type,
     TYPE_CHECKING,
@@ -963,8 +966,13 @@ def onnxrt_model_iter_fn(model, inputs, collect_outputs=True):
     return onnxrt_model_iter_fn
 
 
 def create_onnx_fn(onnx_model: OnnxModelFromTorchScript, pt_inputs):
+    # NOTE: Make the perf comparison fair by moving the i/o adapting out of the measured function:
+    # 1. Pre-adapt `pt_inputs` to `onnx_inputs` here.
+    # 2. Drop the `onnx_outputs` to `pt_outputs` adapting; output comparison is not part of the perf measurement.
+    onnx_inputs = onnx_model.adapt_pt_inputs_to_onnx(pt_inputs)
+
     def onnxrt_model_iter_fn(model, inputs, collect_outputs=True):
-        return onnx_model.run(pt_inputs)
+        return onnx_model.run_with_onnx_inputs(onnx_inputs)
 
     return onnxrt_model_iter_fn
 
@@ -1254,78 +1262,59 @@ def wrapper(self, *args, **kwargs) -> Any:
     return wrapper
 
 
-class OnnxModelFromTorchScript:
-    """TorchScript based onnx export. `torch.onnx.export`
-
-    TODO(bowbao):
-    * large model export failed.
-        Onnx Model is larger than 2GB, but exporter makes decision based pt model size, which is
-        smaller than 2GB.
-    * OOM on slightly larger model.
-        Both pt model and ort inference session are on gpu. Attempt has been made to move ORT to
-        cuda:1, however ORT perf drop significantly.
-        For now running everything with batch_size 1 set in launch script.
-    """
-
-    TORCH_TO_NUMPY_DTYPE = {
-        torch.float16: np.float16,
-        torch.float32: np.float32,
-        torch.float64: np.float64,
-        torch.uint8: np.uint8,
-        torch.int8: np.int8,
-        torch.int16: np.int16,
-        torch.int32: np.int32,
-        torch.int64: np.longlong,
-        torch.bool: np.bool_,
-    }
+class OnnxModel(abc.ABC):
+    _COMPILER_NAME: str
 
     def __init__(self, output_directory, model, example_inputs, dynamic_shapes: bool):
-        assert not dynamic_shapes, "NYI dynamic shapes for OnnxModelFromTorchScript"
-        self.model_path = self._generate_onnx_model_path(output_directory)
-        self._export(
-            model,
-            example_inputs,
-            self.model_path,
-            opset_version=17,
-            do_constant_folding=False,
-            verbose=False,
-        )
-        self.onnx_session = self._init_ort_session(self.model_path)
-
-    def _generate_onnx_model_path(
-        self, output_directory: str, onnx_model_folder_name: str = "bench_onnx_models"
-    ) -> str:
         # Hack to get model name.
         from torch._functorch import aot_autograd
 
         model_name = aot_autograd.model_name
-        model_path = pathlib.Path(output_directory, onnx_model_folder_name, model_name)
+        self.model_dir = self._generate_onnx_model_directory(
+            output_directory, self._COMPILER_NAME, model_name
+        )
+        self.model_path = str(
+            self.model_dir / f"{model_name}_{self._COMPILER_NAME}.onnx"
+        )
+
+    @classmethod
+    def _generate_onnx_model_directory(
+        cls, output_directory: str, compiler_name: str, model_name: str
+    ) -> pathlib.Path:
+        model_path = pathlib.Path(
+            output_directory,
+            ".onnx_models",
+            model_name,
+            compiler_name,
+        )
         if model_path.exists() and model_path.is_dir():
             shutil.rmtree(model_path)
         model_path.mkdir(parents=True, exist_ok=True)
-        return str(model_path / "model.onnx")
+        return model_path
 
-    def _export(self, model, example_inputs, output_path: str, /, **kwargs) -> None:
-        # Hack for huggingface models (kwargs only).
-        if isinstance(example_inputs, dict):
+    @abc.abstractmethod
+    def format_pt_inputs(self, pt_inputs: Any) -> Sequence[torch.Tensor]:
+        ...
 
-            class WrapperModel(torch.nn.Module):
-                def __init__(self, model, keys):
-                    super().__init__()
-                    self.model = model
-                    self.keys = keys
-
-                def forward(self, *args):
-                    return self.model(**dict(zip(self.keys, args)))
+    @abc.abstractmethod
+    def format_pt_outputs(self, pt_outputs: Any) -> Sequence[torch.Tensor]:
+        ...
 
-            model = WrapperModel(model, list(example_inputs.keys()))
+    def adapt_pt_inputs_to_onnx(self, pt_inputs) -> Mapping[str, np.ndarray]:
+        pt_inputs = self.format_pt_inputs(pt_inputs)
+        return {
+            ort_input.name: pt_input.cpu().numpy()
+            for ort_input, pt_input in zip(self.onnx_session.get_inputs(), pt_inputs)
+        }
 
-        torch.onnx.export(
-            model,
-            self.format_pt_inputs(example_inputs),
-            output_path,
-            **kwargs,
-        )
+    def adapt_onnx_outputs_to_pt(self, onnx_outputs: List[np.ndarray]) -> Any:
+        pt_outputs = [
+            torch.from_numpy(onnx_output).to(current_device)
+            for onnx_output in onnx_outputs
+        ]
+        if len(pt_outputs) == 1:
+            return pt_outputs[0]
+        return pt_outputs
 
     def _init_ort_session(self, model_path: str):
         import onnxruntime
@@ -1356,40 +1345,6 @@ def cpu(self) -> Self:
         self.onnx_session.set_providers(["CPUExecutionProvider"])
         return self
 
-    def format_pt_inputs(self, pt_inputs):
-        # NOTE(bowbao): For huggingface benchmark, pt_inputs are formatted as dictionary,
-        # and consumed like `model(**pt_inputs)`.
-        # For other benchmarks, pt_inputs are formatted as tuple and consumed
-        # like `model(*pt_inputs)`.
-        if isinstance(pt_inputs, dict):
-            pt_inputs = list(pt_inputs.values())
-        if isinstance(pt_inputs, torch.Tensor):
-            pt_inputs = (pt_inputs,)
-        return tuple(arg.contiguous() for arg in pt_inputs)
-
-    def format_pt_outputs(self, pt_outputs):
-        if isinstance(pt_outputs, torch.Tensor):
-            pt_outputs = (pt_outputs,)
-
-        pt_outputs, _ = pytree.tree_flatten(pt_outputs)
-
-        # Hack for huggingface model outputs
-        try:
-            from transformers import modeling_outputs
-        except ImportError:
-            pass
-        else:
-
-            def _to_tuple(x):
-                if isinstance(x, modeling_outputs.ModelOutput):
-                    return x.to_tuple()
-                return x
-
-            pt_outputs = pytree.tree_map(_to_tuple, pt_outputs)
-            pt_outputs, _ = pytree.tree_flatten(pt_outputs)
-
-        return pt_outputs
-
     def create_outputs(self, *example_outputs):
         return tuple(torch.empty_like(x) for x in example_outputs)
 
@@ -1433,34 +1388,145 @@ def run_with_iobinding(self, iobinding, outputs):
         self.onnx_session.run_with_iobinding(iobinding)
         return outputs
 
+    def run_with_onnx_inputs(self, onnx_inputs):
+        return self.onnx_session.run(None, onnx_inputs)
+
+    @classmethod
+    def save_tensor_data(cls, numpy_tensor, output_path):
+        from onnx import numpy_helper
+
+        proto_tensor = numpy_helper.from_array(numpy_tensor)
+        with open(output_path, "wb") as f:
+            f.write(proto_tensor.SerializeToString())
+
+    def run_and_serialize_inputs_outputs(self, pt_inputs):
+        onnx_inputs = self.adapt_pt_inputs_to_onnx(pt_inputs)
+        onnx_outputs = self.run_with_onnx_inputs(onnx_inputs)
+
+        test_data_dir = self.model_dir / "test_data_set_0"
+        test_data_dir.mkdir(parents=True, exist_ok=True)
+
+        for i, onnx_input in enumerate(onnx_inputs.values()):
+            self.save_tensor_data(onnx_input, str(test_data_dir / f"input_{i}.pb"))
+        for i, onnx_output in enumerate(onnx_outputs):
+            self.save_tensor_data(onnx_output, str(test_data_dir / f"output_{i}.pb"))
+
+        return self.adapt_onnx_outputs_to_pt(onnx_outputs)
+
     def run(self, pt_inputs):
         # NOTE: For CUDA performance testing, use `run_with_iobinding` to exclude memory
         # copying overhead for inputs/outputs between cpu and gpu.
         # Otherwise perf number is inaccurate.
-        pt_inputs = self.format_pt_inputs(pt_inputs)
-        onnx_inputs = {
-            ort_input.name: pt_input.cpu().numpy()
-            for ort_input, pt_input in zip(self.onnx_session.get_inputs(), pt_inputs)
-        }
-        ort_outputs = self.onnx_session.run(None, onnx_inputs)
-        pt_outputs = [
-            torch.from_numpy(ort_output).to(current_device)
-            for ort_output in ort_outputs
-        ]
-        if len(pt_outputs) == 1:
-            return pt_outputs[0]
+        onnx_inputs = self.adapt_pt_inputs_to_onnx(pt_inputs)
+        onnx_outputs = self.run_with_onnx_inputs(onnx_inputs)
+        return self.adapt_onnx_outputs_to_pt(onnx_outputs)
+
+
+class OnnxModelFromTorchScript(OnnxModel):
+    """TorchScript based onnx export. `torch.onnx.export`
+
+    TODO(bowbao):
+    * Large model export fails.
+        The ONNX model is larger than 2GB, but the exporter makes its decision based on the
+        pt model size, which is smaller than 2GB.
+    * OOM on slightly larger models.
+        Both the pt model and the ORT inference session are on gpu. An attempt has been made to
+        move ORT to cuda:1; however, ORT perf drops significantly.
+        For now everything runs with batch_size 1 set in the launch script.
+    """
+
+    TORCH_TO_NUMPY_DTYPE = {
+        torch.float16: np.float16,
+        torch.float32: np.float32,
+        torch.float64: np.float64,
+        torch.uint8: np.uint8,
+        torch.int8: np.int8,
+        torch.int16: np.int16,
+        torch.int32: np.int32,
+        torch.int64: np.longlong,
+        torch.bool: np.bool_,
+    }
+
+    _COMPILER_NAME = "torchscript"
+
+    def __init__(self, output_directory, model, example_inputs, dynamic_shapes: bool):
+        if dynamic_shapes:
+            raise NotImplementedError("NYI dynamic shapes for OnnxModelFromTorchScript")
+        super().__init__(output_directory, model, example_inputs, dynamic_shapes)
+        self._export(
+            model,
+            example_inputs,
+            self.model_path,
+            opset_version=17,
+            do_constant_folding=False,
+            verbose=False,
+        )
+        self.onnx_session = self._init_ort_session(self.model_path)
+
+    def _export(self, model, example_inputs, output_path: str, /, **kwargs) -> None:
+        # Hack for huggingface models (kwargs only).
+        if isinstance(example_inputs, dict):
+
+            class WrapperModel(torch.nn.Module):
+                def __init__(self, model, keys):
+                    super().__init__()
+                    self.model = model
+                    self.keys = keys
+
+                def forward(self, *args):
+                    return self.model(**dict(zip(self.keys, args)))
+
+            model = WrapperModel(model, list(example_inputs.keys()))
+
+        torch.onnx.export(
+            model,
+            self.format_pt_inputs(example_inputs),
+            output_path,
+            **kwargs,
+        )
+
+    def format_pt_inputs(self, pt_inputs):
+        # NOTE(bowbao): For the huggingface benchmark, pt_inputs are formatted as a dictionary
+        # and consumed like `model(**pt_inputs)`.
+        # For other benchmarks, pt_inputs are formatted as a tuple and consumed
+        # like `model(*pt_inputs)`.
+        if isinstance(pt_inputs, dict):
+            pt_inputs = list(pt_inputs.values())
+        if isinstance(pt_inputs, torch.Tensor):
+            pt_inputs = (pt_inputs,)
+        return tuple(arg.contiguous() for arg in pt_inputs)
+
+    def format_pt_outputs(self, pt_outputs):
+        if isinstance(pt_outputs, torch.Tensor):
+            pt_outputs = (pt_outputs,)
+
+        pt_outputs, _ = pytree.tree_flatten(pt_outputs)
+
+        # Hack for huggingface model outputs
+        try:
+            from transformers import modeling_outputs
+        except ImportError:
+            pass
+        else:
+
+            def _to_tuple(x):
+                if isinstance(x, modeling_outputs.ModelOutput):
+                    return x.to_tuple()
+                return x
+
+            pt_outputs = pytree.tree_map(_to_tuple, pt_outputs)
+            pt_outputs, _ = pytree.tree_flatten(pt_outputs)
 
         return pt_outputs
 
 
-class OnnxModelFromDynamo(OnnxModelFromTorchScript):
+class OnnxModelFromDynamo(OnnxModel):
     """Dynamo and Fx based export. `torch.onnx.dynamo_export`."""
 
-    _EXPORTED_MODEL_FOLDER_NAME = "bench_dynamo_onnx_model"
+    _COMPILER_NAME = "dynamo"
 
     def __init__(self, output_directory, model, example_inputs, dynamic_shapes: bool):
-        self.model_path = self._generate_onnx_model_path(
-            output_directory, self._EXPORTED_MODEL_FOLDER_NAME
-        )
+        super().__init__(output_directory, model, example_inputs, dynamic_shapes)
         self._dynamic_shapes = dynamic_shapes
         self._export_output = self._export(model, example_inputs, self.model_path)
         self.onnx_session = self._init_ort_session(self.model_path)
@@ -1488,7 +1554,7 @@ def format_pt_outputs(self, pt_outputs):
 class OnnxModelFromDynamoAotInline(OnnxModelFromDynamo):
     """Dynamo and Fx based export, with AOT inline post export. `torch.onnx.dynamo_export`."""
 
-    _EXPORTED_MODEL_FOLDER_NAME = "bench_dynamo_onnx_aot_inline_model"
+    _COMPILER_NAME = "dynamo_aot_inline"
 
     def _export(
         self, model, example_inputs, output_path: str
@@ -1661,6 +1727,7 @@ def optimize_onnx_ctx(
     # 2. Create iobinding for ORT.
     # 3. Run ORT for n iterations.
     onnx_model: Optional[OnnxModelFromTorchScript] = None
+    test_data_dumped = False
 
     def run_n_iterations_onnx(model, inputs, n=2):
         from torch.onnx._internal import exporter
@@ -1686,7 +1753,14 @@ def run_n_iterations_onnx(model, inputs, n=2):
 
         for _ in range(n):
             try:
-                outputs = onnx_model.run(inputs)
+                nonlocal test_data_dumped
+                if not test_data_dumped:
+                    # Serializes inputs and outputs to .pb files for further offline analysis.
+                    # Due to this, this function is not, and should not be, used for perf measurement.
+                    outputs = onnx_model.run_and_serialize_inputs_outputs(inputs)
+                    test_data_dumped = True
+                else:
+                    outputs = onnx_model.run(inputs)
             except Exception as e:
                 err_msg = str(e)
                 oom_msgs = (
@@ -2354,6 +2428,9 @@ def record_status(accuracy_status, dynamo_start_stats):
                 ) = _OnnxPatch.patch_non_tensor_outputs(
                     correct_result, new_result, fp64_outputs
                 )
+                # TODO: Store correct_result in the dumped file for offline onnx model validation.
+                # The downside, and a potential problem, is that the output formats may differ.
+                # E.g., the output order might not match, None might be part of the output, etc.
 
                 try:
                     if not same(
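
---

Note (not part of the patch): a minimal sketch of how the test data dumped by `run_and_serialize_inputs_outputs` could be consumed for offline validation. It assumes the layout created above, i.e. `.onnx_models/<model_name>/<compiler_name>/` containing one `<model_name>_<compiler_name>.onnx` file plus `test_data_set_0/input_*.pb` and `output_*.pb`; the `model_dir` value and the tolerances are hypothetical examples.

    # Offline validation sketch for the dumped test data (assumes the layout
    # produced by OnnxModel above; model_dir is a hypothetical example path).
    import glob
    import os

    import numpy as np
    import onnx
    import onnxruntime
    from onnx import numpy_helper

    model_dir = ".onnx_models/resnet18/torchscript"  # hypothetical example
    test_data_dir = os.path.join(model_dir, "test_data_set_0")


    def load_pb_tensors(pattern):
        # Each .pb file is a serialized onnx.TensorProto written via numpy_helper.from_array.
        arrays = []
        for path in sorted(glob.glob(pattern)):  # use a numeric sort if there are >9 tensors
            tensor = onnx.TensorProto()
            with open(path, "rb") as f:
                tensor.ParseFromString(f.read())
            arrays.append(numpy_helper.to_array(tensor))
        return arrays


    inputs = load_pb_tensors(os.path.join(test_data_dir, "input_*.pb"))
    expected = load_pb_tensors(os.path.join(test_data_dir, "output_*.pb"))

    onnx_path = glob.glob(os.path.join(model_dir, "*.onnx"))[0]
    session = onnxruntime.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    # input_{i}.pb files are written in the session's input order, so zip names positionally.
    feeds = dict(zip((i.name for i in session.get_inputs()), inputs))
    actual = session.run(None, feeds)

    for ref, out in zip(expected, actual):
        np.testing.assert_allclose(ref, out, rtol=1e-3, atol=1e-4)

The `test_data_set_0` naming follows the convention used by ONNX backend and model-zoo test data, so standard ONNX tooling that understands that layout should also be able to pick these files up.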