Skip to content

Commit

Permalink
Add proper iobinding synchronize for ONNX cuda bench (#115773)
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/pytorch#115773
Approved by: https://github.com/thiagocrepaldi
ghstack dependencies: #115670, #115673

Reviewed By: jeanschmidt

Differential Revision: D52244281

fbshipit-source-id: ad9d34e4b8578a9c3799b5d76089651b1e453d7e
  • Loading branch information
BowenBao authored and facebook-github-bot committed Dec 18, 2023
1 parent afe0209 commit b599ae4
Showing 1 changed file with 29 additions and 10 deletions.
39 changes: 29 additions & 10 deletions userbenchmark/dynamo/dynamobench/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,23 @@ def speedup_experiment_ds(args, model_iter_fn, model, example_inputs):
return output_str


@contextlib.contextmanager
def override_synchronize_with_onnx_iobinding(iobinding):
global synchronize
prev_synchrnoize = synchronize
try:
if iobinding is not None:

def new_synchronize():
iobinding.synchronize_inputs()
iobinding.synchronize_outputs()

synchronize = new_synchronize
yield
finally:
synchronize = prev_synchrnoize


def speedup_experiment_onnx(
args,
model_iter_fn,
Expand Down Expand Up @@ -737,7 +754,7 @@ def onnxrt_model_iter_fn(model, inputs, collect_outputs=True):
if collect_outputs:
return outputs

return onnxrt_model_iter_fn
return onnxrt_model_iter_fn, iobinding

def create_onnx_fn(onnx_model: OnnxModel, pt_inputs):
# NOTE: Making perf comparison fair by moving out the i/o adapting part.
Expand All @@ -753,18 +770,20 @@ def onnxrt_model_iter_fn(model, inputs, collect_outputs=True):
def timed_onnx(model, onnx_model: OnnxModel, inputs):
if current_device == "cpu" or onnx_model.is_cpu():
onnxrt_model_iter_fn = create_onnx_fn(onnx_model, inputs)
iobinding = None
else:
onnxrt_model_iter_fn = create_onnx_input_binded_fn(
onnxrt_model_iter_fn, iobinding = create_onnx_input_binded_fn(
onnx_model, inputs, expected_output
)
return timed(
model,
onnxrt_model_iter_fn,
inputs,
return_result=True,
times=times,
collect_outputs=args.collect_outputs,
)
with override_synchronize_with_onnx_iobinding(iobinding):
return timed(
model,
onnxrt_model_iter_fn,
inputs,
return_result=True,
times=times,
collect_outputs=args.collect_outputs,
)

# Insert ONNX warm-up
inputs = (
Expand Down

0 comments on commit b599ae4

Please sign in to comment.