Skip to content

Commit

Permalink
Deepcopy model to another device before export to avoid OOM (#118710)
Browse files Browse the repository at this point in the history
Summary:
Prior to onnx export, the model is deepcopied to avoid modifications that may affect later performance profiling. However this increases the memory requirement on the device.
This PR modifies the script to deepcopy and export the model on another device when possible.

X-link: pytorch/pytorch#118710
Approved by: https://github.com/thiagocrepaldi

Reviewed By: clee2000

Differential Revision: D53296686

fbshipit-source-id: e764fcf3c4f15f4f8793623571a09fd1b4263898
  • Loading branch information
BowenBao authored and facebook-github-bot committed Feb 1, 2024
1 parent e60829a commit 8e2c99f
Showing 1 changed file with 37 additions and 2 deletions.
39 changes: 37 additions & 2 deletions userbenchmark/dynamo/dynamobench/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,6 +1222,31 @@ def __init__(
self.model_dir / f"{model_name}_{self._COMPILER_NAME}.onnx"
)

def _determine_deepcopy_target_device(self):
if current_device == "cpu":
target_device = "cpu"
else:
if torch.cuda.device_count() > 1:
# Copy to another cuda device to avoid OOM.
target_device = "cuda:1"
else:
target_device = "cuda"
return target_device

def deepcopy_model_and_inputs_to_device(self, model, example_inputs, target_device):
# Deepcopy model before export to avoid modification to baseline model.
# To avoid OOM, the model is first moved to CPU. Both models are then moved to device.
model_device = next(model.parameters()).device
model.to("cpu")
model_copy = copy.deepcopy(model).to(target_device)
model.to(model_device)

target_device_example_inputs = tree_map_only(
torch.Tensor, lambda x: x.to(device=target_device), example_inputs
)

return model_copy, target_device_example_inputs

@classmethod
def _generate_onnx_model_directory(
cls, output_directory: str, compiler_name: str, model_name: str
Expand Down Expand Up @@ -1404,7 +1429,9 @@ def __init__(
def _export(self, model, example_inputs, output_path: str, /, **kwargs) -> None:
if self.copy_before_export:
# Deepcopy model before export to avoid modification to baseline model.
model = copy.deepcopy(model)
model, example_inputs = self.deepcopy_model_and_inputs_to_device(
model, example_inputs, self._determine_deepcopy_target_device()
)

# Hack for huggingface models (kwargs only).
if isinstance(example_inputs, dict):
Expand Down Expand Up @@ -1486,7 +1513,9 @@ def _export(
) -> torch.onnx.ONNXProgram:
if self.copy_before_export:
# Deepcopy model before export to avoid modification to baseline model.
model = copy.deepcopy(model)
model, example_inputs = self.deepcopy_model_and_inputs_to_device(
model, example_inputs, self._determine_deepcopy_target_device()
)

example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes)
Expand All @@ -1513,6 +1542,12 @@ class OnnxModelFromDynamoAotInline(OnnxModelFromDynamo):
def _export(
self, model, example_inputs, output_path: str
) -> torch.onnx.ONNXProgram:
if self.copy_before_export:
# Deepcopy model before export to avoid modification to baseline model.
model, example_inputs = self.deepcopy_model_and_inputs_to_device(
model, example_inputs, self._determine_deepcopy_target_device()
)

example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes)
onnx_program = torch.onnx.dynamo_export(
Expand Down

0 comments on commit 8e2c99f

Please sign in to comment.