diff --git a/userbenchmark/dynamo/dynamobench/common.py b/userbenchmark/dynamo/dynamobench/common.py
index d3ca51e8f4..cd0b98d4f4 100644
--- a/userbenchmark/dynamo/dynamobench/common.py
+++ b/userbenchmark/dynamo/dynamobench/common.py
@@ -703,6 +703,24 @@ def speedup_experiment_ds(args, model_iter_fn, model, example_inputs):
     return output_str
 
 
+@contextlib.contextmanager
+def override_synchronize_with_onnx_iobinding(iobinding):
+    """Temporarily point the module-level `synchronize` at ORT IOBinding syncs."""
+    global synchronize
+    prev_synchronize = synchronize
+    try:
+        if iobinding is not None:
+
+            def new_synchronize():
+                iobinding.synchronize_inputs()
+                iobinding.synchronize_outputs()
+
+            synchronize = new_synchronize
+        yield
+    finally:
+        synchronize = prev_synchronize
+
+
 def speedup_experiment_onnx(
     args,
     model_iter_fn,
@@ -737,7 +755,7 @@ def onnxrt_model_iter_fn(model, inputs, collect_outputs=True):
             if collect_outputs:
                 return outputs
 
-        return onnxrt_model_iter_fn
+        return onnxrt_model_iter_fn, iobinding
 
     def create_onnx_fn(onnx_model: OnnxModel, pt_inputs):
         # NOTE: Making perf comparison fair by moving out the i/o adapting part.
@@ -753,18 +771,20 @@ def onnxrt_model_iter_fn(model, inputs, collect_outputs=True):
     def timed_onnx(model, onnx_model: OnnxModel, inputs):
         if current_device == "cpu" or onnx_model.is_cpu():
             onnxrt_model_iter_fn = create_onnx_fn(onnx_model, inputs)
+            iobinding = None
         else:
-            onnxrt_model_iter_fn = create_onnx_input_binded_fn(
+            onnxrt_model_iter_fn, iobinding = create_onnx_input_binded_fn(
                 onnx_model, inputs, expected_output
             )
-        return timed(
-            model,
-            onnxrt_model_iter_fn,
-            inputs,
-            return_result=True,
-            times=times,
-            collect_outputs=args.collect_outputs,
-        )
+        with override_synchronize_with_onnx_iobinding(iobinding):
+            return timed(
+                model,
+                onnxrt_model_iter_fn,
+                inputs,
+                return_result=True,
+                times=times,
+                collect_outputs=args.collect_outputs,
+            )
 
     # Insert ONNX warm-up
     inputs = (