From 8e2c99f6d7e171e50f784a7990e0c52e3767cf8d Mon Sep 17 00:00:00 2001 From: BowenBao Date: Wed, 31 Jan 2024 20:24:58 -0800 Subject: [PATCH] Deepcopy model to another device before export to avoid OOM (#118710) Summary: Prior to onnx export, the model is deepcopied to avoid modifications that may affect later performance profiling. However this increases the memory requirement on the device. This PR modifies the script to deepcopy and export the model on another device when possible. X-link: https://github.com/pytorch/pytorch/pull/118710 Approved by: https://github.com/thiagocrepaldi Reviewed By: clee2000 Differential Revision: D53296686 fbshipit-source-id: e764fcf3c4f15f4f8793623571a09fd1b4263898 --- userbenchmark/dynamo/dynamobench/common.py | 39 ++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/userbenchmark/dynamo/dynamobench/common.py b/userbenchmark/dynamo/dynamobench/common.py index 831dfe062e..98ea5c6d71 100644 --- a/userbenchmark/dynamo/dynamobench/common.py +++ b/userbenchmark/dynamo/dynamobench/common.py @@ -1222,6 +1222,31 @@ def __init__( self.model_dir / f"{model_name}_{self._COMPILER_NAME}.onnx" ) + def _determine_deepcopy_target_device(self): + if current_device == "cpu": + target_device = "cpu" + else: + if torch.cuda.device_count() > 1: + # Copy to another cuda device to avoid OOM. + target_device = "cuda:1" + else: + target_device = "cuda" + return target_device + + def deepcopy_model_and_inputs_to_device(self, model, example_inputs, target_device): + # Deepcopy model before export to avoid modification to baseline model. + # To avoid OOM, the model is first moved to CPU. Both models are then moved to device. + model_device = next(model.parameters()).device + model.to("cpu") + model_copy = copy.deepcopy(model).to(target_device) + model.to(model_device) + + target_device_example_inputs = tree_map_only( + torch.Tensor, lambda x: x.to(device=target_device), example_inputs + ) + + return model_copy, target_device_example_inputs + @classmethod def _generate_onnx_model_directory( cls, output_directory: str, compiler_name: str, model_name: str @@ -1404,7 +1429,9 @@ def __init__( def _export(self, model, example_inputs, output_path: str, /, **kwargs) -> None: if self.copy_before_export: # Deepcopy model before export to avoid modification to baseline model. - model = copy.deepcopy(model) + model, example_inputs = self.deepcopy_model_and_inputs_to_device( + model, example_inputs, self._determine_deepcopy_target_device() + ) # Hack for huggingface models (kwargs only). if isinstance(example_inputs, dict): @@ -1486,7 +1513,9 @@ def _export( ) -> torch.onnx.ONNXProgram: if self.copy_before_export: # Deepcopy model before export to avoid modification to baseline model. - model = copy.deepcopy(model) + model, example_inputs = self.deepcopy_model_and_inputs_to_device( + model, example_inputs, self._determine_deepcopy_target_device() + ) example_args, example_kwargs = _normalize_bench_inputs(example_inputs) options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes) @@ -1513,6 +1542,12 @@ class OnnxModelFromDynamoAotInline(OnnxModelFromDynamo): def _export( self, model, example_inputs, output_path: str ) -> torch.onnx.ONNXProgram: + if self.copy_before_export: + # Deepcopy model before export to avoid modification to baseline model. + model, example_inputs = self.deepcopy_model_and_inputs_to_device( + model, example_inputs, self._determine_deepcopy_target_device() + ) + example_args, example_kwargs = _normalize_bench_inputs(example_inputs) options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes) onnx_program = torch.onnx.dynamo_export(