Skip to content

Commit

Permalink
Benchmark onnx export w/ ort fusions (#125700)
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/pytorch#125700
Approved by: https://github.com/thiagocrepaldi

Reviewed By: izaitsevfb

Differential Revision: D57138493

fbshipit-source-id: c5c7b5669b951f20dbbe584543bcef51ef1c7a50
  • Loading branch information
BowenBao authored and facebook-github-bot committed May 9, 2024
1 parent 1708afa commit 6f3faa0
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions userbenchmark/dynamo/dynamobench/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1616,6 +1616,40 @@ def _export(
return onnx_program


class OnnxModelFromDynamoAotOptimize(OnnxModelFromDynamo):
"""Dynamo and Fx based export, with AOT optimize post export. `torch.onnx.dynamo_export`."""

_COMPILER_NAME = "dynamo_aot_optimize"

def _export(
self, model, example_inputs, output_path: str
) -> torch.onnx.ONNXProgram:
if self.copy_before_export:
# Deepcopy model before export to avoid modification to baseline model.
model, example_inputs = self.deepcopy_model_and_inputs_to_device(
model, example_inputs, self._determine_deepcopy_target_device()
)

example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes)
export_output = torch.onnx.dynamo_export(
model, *example_args, **example_kwargs, export_options=options
)

import onnx
from onnxscript.rewriter.onnxruntime import rewrite

model_proto = rewrite(export_output.model_proto)
onnx.save_model(
model_proto,
output_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
)

return export_output


class _OnnxPatch:
@classmethod
def patch_non_tensor_outputs(cls, correct_result, new_result, fp64_outputs):
Expand Down Expand Up @@ -3475,6 +3509,12 @@ def get_example_inputs(self):
action="store_true",
help="Measure speedup with Dynamo ONNX AOT Inline, i.e. `torch.onnx.dynamo_export`",
)
group.add_argument(
"--dynamo-onnx-aot-optimize",
"--dynamo_onnx_aot_optimize",
action="store_true",
help="Measure speedup with Dynamo ONNX w/ ort fusions, i.e. `torch.onnx.dynamo_export`",
)
group.add_argument(
"--backend",
choices=torch._dynamo.list_backends(exclude_tags=None),
Expand Down Expand Up @@ -3839,6 +3879,17 @@ def run(runner, args, original_dir=None):
experiment = speedup_experiment_onnx
output_filename = "dynamo_onnx_aot_inline.csv"
current_onnx_compiler = "dynamo"
elif args.dynamo_onnx_aot_optimize:
optimize_ctx = functools.partial(
optimize_onnx_ctx,
args.output_directory or ".",
OnnxModelFromDynamoAotOptimize,
dynamic_shapes=args.dynamic_shapes,
copy_before_export=args.performance,
)
experiment = speedup_experiment_onnx
output_filename = "dynamo_onnx_aot_optimize.csv"
current_onnx_compiler = "dynamo"
elif args.speedup_dynamo_ts:
optimize_ctx = torch._dynamo.optimize("ts", nopython=args.nopython)
experiment = speedup_experiment
Expand Down

0 comments on commit 6f3faa0

Please sign in to comment.