Pass TorchIR to AOTInductor
Summary:
Updates `_export.aot_compile` to pass a Torch IR graph to Inductor, allowing Inductor to run the pre_grad_passes and reuse more of its code.
Also updates the API to return only the `so_path`, no longer returning the exported program. The pytree call spec is now serialized and placed inside the generated model code. When calling the model, because no C++ pytree implementation is linked yet, we can access the call specs through `get_call_spec()` and do the pytree flatten/unflatten in Python.
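For illustration, a minimal sketch of the new calling convention. This assumes `model`, `example_args`, and `example_kwargs` are already defined, and that `module` is the compiled library loaded through `torch.utils.cpp_extension.load_inline` with the `run` and `get_call_spec` entry points, as in the benchmark change below:

    import torch
    import torch.utils._pytree as pytree
    import torch.fx._pytree as fx_pytree

    # model, example_args, example_kwargs, and module are assumed from context.
    # aot_compile now returns only the path to the generated shared library.
    so_path = torch._export.aot_compile(model, example_args, example_kwargs)

    # The serialized call specs live in the generated model code; fetch them
    # via get_call_spec() and do the flatten/unflatten on the Python side.
    in_spec_str, out_spec_str = module.get_call_spec()
    in_spec = pytree.treespec_loads(in_spec_str)
    out_spec = pytree.treespec_loads(out_spec_str)

    flat_inputs = fx_pytree.tree_flatten_spec((example_args, example_kwargs), in_spec)
    flat_outputs = module.run(flat_inputs)
    outputs = pytree.tree_unflatten(flat_outputs, out_spec)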

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler avikchaudhuri gmagogsfm zhxchen17 tugsbayasgalan

X-link: pytorch/pytorch#110020

Reviewed By: frank-wei, desertfire

Differential Revision: D49599792

Pulled By: angelayi

fbshipit-source-id: 3c20e4bb60eb8ca02d91772b2e40e9bb72f8f905
angelayi authored and facebook-github-bot committed Oct 26, 2023
1 parent 65c0b7d commit 9cf38ec
Showing 1 changed file with 13 additions and 18 deletions.
userbenchmark/dynamo/dynamobench/common.py
@@ -1175,27 +1175,18 @@ def load(cls, model, example_inputs, device):
        example_outputs = model(*example_args, **example_kwargs)
        _register_dataclass_output_as_pytree(example_outputs)

-        so_path, exported = torch._export.aot_compile(
-            model, example_args, example_kwargs
-        )
+        so_path = torch._export.aot_compile(model, example_args, example_kwargs)

        module = torch.utils.cpp_extension.load_inline(
            name="aot_inductor",
            cpp_sources=[aot_inductor_launcher(so_path, device)],
-            functions=["run"],
+            functions=["run", "get_call_spec"],
            with_cuda=(device == "cuda"),
        )

-        value = {
-            "module": module,
-            "exported": exported,
-        }
-        cls.cache[key] = value
+        cls.cache[key] = module

-        return (
-            cls.cache[key]["module"],
-            cls.cache[key]["exported"],
-        )
+        return cls.cache[key]


def export(model, example_inputs):
@@ -1213,15 +1204,19 @@ def opt_export(_, example_inputs):


def export_aot_inductor(model, example_inputs, device):
-    module, exported = AOTInductorModelCache.load(model, example_inputs, device)
+    module = AOTInductorModelCache.load(model, example_inputs, device)
+    call_spec = module.get_call_spec()
+    in_spec = pytree.treespec_loads(call_spec[0])
+    out_spec = pytree.treespec_loads(call_spec[1])

    def opt_aot_inductor(_, example_inputs, collect_outputs=False):
        example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
-        flat_example_inputs = fx_pytree.tree_flatten_spec(
-            (example_args, example_kwargs), exported.call_spec.in_spec
-        )
-        output_tensors = module.run(flat_example_inputs)
-        return pytree.tree_unflatten(output_tensors, exported.call_spec.out_spec)
+
+        flat_inputs = fx_pytree.tree_flatten_spec(
+            (example_args, example_kwargs), in_spec
+        )
+        flat_outputs = module.run(flat_inputs)
+        return pytree.tree_unflatten(flat_outputs, out_spec)

    return opt_aot_inductor

