Update AOTI runner util (#116971)

Summary:
Update the runner used in integration tests after pytorch/torchrec#1604

X-link: pytorch/pytorch#116971
Approved by: https://github.com/chenyang78

Reviewed By: osalpekar

Differential Revision: D52648895

Pulled By: desertfire

fbshipit-source-id: f1af8c60dcc1527c0db86c53144001f688a3a240
desertfire authored and facebook-github-bot committed Jan 10, 2024
1 parent e7b3e2b commit 6014e57
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions userbenchmark/dynamo/dynamobench/common.py
@@ -65,7 +65,7 @@
 
 try:
     from torch._dynamo.utils import clone_inputs, graph_break_reasons
-    from torch._inductor.utils import aot_inductor_launcher, fresh_inductor_cache
+    from torch._inductor.utils import fresh_inductor_cache
 except ImportError:
     from _dynamo.utils import clone_inputs, graph_break_reasons
 from torch._functorch.aot_autograd import set_model_name
@@ -1114,14 +1114,12 @@ def load(cls, model, example_inputs, device):
 
             so_path = torch._export.aot_compile(model, example_args, example_kwargs)
 
-            module = torch.utils.cpp_extension.load_inline(
-                name="aot_inductor",
-                cpp_sources=[aot_inductor_launcher(so_path, device)],
-                functions=["run", "get_call_spec"],
-                with_cuda=(device == "cuda"),
+            runner = (
+                torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)
+                if device == "cpu"
+                else torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1)
             )
 
-            cls.cache[key] = module
+            cls.cache[key] = runner
 
         return cls.cache[key]
 
@@ -1141,8 +1139,8 @@ def opt_export(_, example_inputs):
 
 
 def export_aot_inductor(model, example_inputs, device):
-    module = AOTInductorModelCache.load(model, example_inputs, device)
-    call_spec = module.get_call_spec()
+    runner = AOTInductorModelCache.load(model, example_inputs, device)
+    call_spec = runner.get_call_spec()
     in_spec = pytree.treespec_loads(call_spec[0])
     out_spec = pytree.treespec_loads(call_spec[1])
 
@@ -1152,7 +1150,7 @@ def opt_aot_inductor(_, example_inputs, collect_outputs=False):
         flat_inputs = fx_pytree.tree_flatten_spec(
             (example_args, example_kwargs), in_spec
         )
-        flat_outputs = module.run(flat_inputs)
+        flat_outputs = runner.run(flat_inputs)
         return pytree.tree_unflatten(flat_outputs, out_spec)
 
     return opt_aot_inductor

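For readers unfamiliar with the new API, the change swaps the old torch.utils.cpp_extension.load_inline wrapper built from aot_inductor_launcher for the AOTIModelContainerRunnerCpu / AOTIModelContainerRunnerCuda bindings under torch._C._aoti. Below is a minimal sketch of that new code path, assuming a PyTorch build recent enough to expose these bindings (as this commit requires); the ToyModel, tensor shapes, and CPU default are illustrative only and are not part of the commit.

# Sketch of the runner-based path this commit switches the benchmark to.
# Assumption: the torch._C._aoti bindings used in the diff are available.
import torch
import torch._export
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree


class ToyModel(torch.nn.Module):  # illustrative stand-in for a benchmark model
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.linear(x)


device = "cpu"
model = ToyModel().to(device).eval()
example_args = (torch.randn(2, 8, device=device),)
example_kwargs = {}

# Ahead-of-time compile the model into a shared library, as in AOTInductorModelCache.load.
so_path = torch._export.aot_compile(model, example_args, example_kwargs)

# Build the runner directly instead of compiling a cpp_extension wrapper
# around aot_inductor_launcher (the approach this commit removes). The
# literal 1 is taken verbatim from the diff (number of model instances).
runner = (
    torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)
    if device == "cpu"
    else torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1)
)

# The runner exposes the same call-spec / flatten / run / unflatten flow that
# export_aot_inductor uses in the diff above.
call_spec = runner.get_call_spec()
in_spec = pytree.treespec_loads(call_spec[0])
out_spec = pytree.treespec_loads(call_spec[1])

flat_inputs = fx_pytree.tree_flatten_spec((example_args, example_kwargs), in_spec)
flat_outputs = runner.run(flat_inputs)
outputs = pytree.tree_unflatten(flat_outputs, out_spec)
print(outputs)

The CPU branch is used here only to keep the sketch self-contained; the benchmark code picks the CUDA runner when device is not "cpu", exactly as shown in the diff.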