Skip to content

Commit

Permalink
Rename torch.onnx.ExportOutput* to ONNXProgram* (#112263)
Browse files Browse the repository at this point in the history
Summary:
In PyTorch 2.1, the torch.export API was introduced and the term "export"
became overloaded due to the already existing torch.onnx.export API.

The torch.onnx.dynamo_export API was introduced in PyTorch 2.0 and it
exposed a torch.onnx.ExportOutput class, which can now be confused with
the output of torch.export.export.

To prevent such ambiguity and standardize names around the new
torch.export.ExportedProgram, this PR renames torch.onnx.ExportOutput to
torch.onnx.ONNXProgram

X-link: pytorch/pytorch#112263
Approved by: https://github.com/BowenBao
ghstack dependencies: #112444

Reviewed By: PaliC

Differential Revision: D51057229

fbshipit-source-id: f43c1fa8d1820ad69df61ac9f8f84d5ec3995fbe
  • Loading branch information
Thiago Crepaldi authored and facebook-github-bot committed Nov 7, 2023
1 parent 1e03c23 commit f5b502a
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions userbenchmark/dynamo/dynamobench/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,31 +1527,31 @@ class OnnxModelFromDynamo(OnnxModel):
def __init__(self, output_directory, model, example_inputs, dynamic_shapes: bool):
super().__init__(output_directory, model, example_inputs, dynamic_shapes)
self._dynamic_shapes = dynamic_shapes
self._export_output = self._export(model, example_inputs, self.model_path)
self._onnx_program = self._export(model, example_inputs, self.model_path)
# Clear the model proto to save memory.
# The model proto is saved to disk and no longer needed from `export_output`.
# `export_output` is kept for i/o adapter usage.
self._export_output.model_proto.Clear()
# The model proto is saved to disk and no longer needed from `onnx_program`.
# `onnx_program` is kept for i/o adapter usage.
self._onnx_program.model_proto.Clear()
self.onnx_session = self._init_ort_session(self.model_path)

def _export(
self, model, example_inputs, output_path: str
) -> torch.onnx.ExportOutput:
) -> torch.onnx.ONNXProgram:
example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes)
export_output = torch.onnx.dynamo_export(
onnx_program = torch.onnx.dynamo_export(
model, *example_args, **example_kwargs, export_options=options
)

export_output.save(output_path)
return export_output
onnx_program.save(output_path)
return onnx_program

def format_pt_inputs(self, pt_inputs):
pt_args, pt_kwargs = _normalize_bench_inputs(pt_inputs)
return self._export_output.adapt_torch_inputs_to_onnx(*pt_args, **pt_kwargs)
return self._onnx_program.adapt_torch_inputs_to_onnx(*pt_args, **pt_kwargs)

def format_pt_outputs(self, pt_outputs):
return self._export_output.adapt_torch_outputs_to_onnx(pt_outputs)
return self._onnx_program.adapt_torch_outputs_to_onnx(pt_outputs)


class OnnxModelFromDynamoAotInline(OnnxModelFromDynamo):
Expand All @@ -1561,10 +1561,10 @@ class OnnxModelFromDynamoAotInline(OnnxModelFromDynamo):

def _export(
self, model, example_inputs, output_path: str
) -> torch.onnx.ExportOutput:
) -> torch.onnx.ONNXProgram:
example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes)
export_output = torch.onnx.dynamo_export(
onnx_program = torch.onnx.dynamo_export(
model, *example_args, **example_kwargs, export_options=options
)
# Apply AOT inline post export.
Expand All @@ -1575,12 +1575,12 @@ def _export(
# Workaround for inliner not supporting with models larger than 2GB.
# Save model to disk first separating out external data,
# and load back without external data for inliner to work on.
model_proto = export_output.model_proto
model_proto = onnx_program.model_proto
onnx.save_model(model_proto, output_path, save_as_external_data=True)
model_proto = onnx.load(output_path, load_external_data=False)
model_proto = onnx.inliner.inline_local_functions(model_proto)
onnx.save_model(model_proto, output_path)
return export_output
return onnx_program


class _OnnxPatch:
Expand Down Expand Up @@ -1786,7 +1786,7 @@ def run_n_iterations_onnx(model, inputs, n=2):
return outputs
except exporter.OnnxExporterError as e:
# `torch.onnx.dynamo_export` raises error that encloses diagnostics.
diagnostic_context = e.export_output.diagnostic_context
diagnostic_context = e.onnx_program.diagnostic_context
for parsed_error in parser.parse_diagnostic_context(diagnostic_context):
output_csv(
output_error_filename, parsed_error.headers, parsed_error.row
Expand Down

0 comments on commit f5b502a

Please sign in to comment.