Update test utility to use AOTIModelRunner
Summary:
X-link: pytorch/pytorch#111657

Use the AOTIModelRunner provided by libtorch instead of the custom-written RAIIModelContainer for testing. This change also makes it possible to run AOTInductor benchmarks on CPU.
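
For context, a minimal sketch of what a device-aware aot_inductor_launcher(so_path, device) can return for torch.utils.cpp_extension.load_inline. This is an illustration only, not the actual pytorch/pytorch implementation; the runner header and class names are inferred from the commit title and may differ:

    def aot_inductor_launcher(so_path: str, device: str) -> str:
        # Sketch only: returns C++ source that wraps the AOTInductor-compiled
        # shared library behind a single run() entry point. Header and class
        # names (AOTIModelRunnerCpu / AOTIModelRunnerCuda) are assumptions.
        if device == "cuda":
            return f"""
            #include <torch/csrc/inductor/aoti_model_runner_cuda.h>

            torch::inductor::AOTIModelRunnerCuda runner("{so_path}");

            std::vector<at::Tensor> run(std::vector<at::Tensor>& input_tensors) {{
                return runner.run(input_tensors);
            }}
            """
        elif device == "cpu":
            return f"""
            #include <torch/csrc/inductor/aoti_model_runner.h>

            torch::inductor::AOTIModelRunnerCpu runner("{so_path}");

            std::vector<at::Tensor> run(std::vector<at::Tensor>& input_tensors) {{
                return runner.run(input_tensors);
            }}
            """
        raise ValueError(f"unsupported device: {device}")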

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

imported-using-ghimport

Reviewed By: angelayi

Differential Revision: D50560764

Pulled By: desertfire

fbshipit-source-id: dbcd6c029bb0a36596bb0de894c7cffc20f0aae0
desertfire authored and facebook-github-bot committed Oct 23, 2023
1 parent b5631ed commit 2f9b20e
Showing 1 changed file with 11 additions and 9 deletions: userbenchmark/dynamo/dynamobench/common.py
@@ -716,7 +716,9 @@ def maybe_mark_profile(*args, **kwargs):
 
     with maybe_profile(args.export_profiler_trace) as p:
         if args.export_aot_inductor:
-            frozen_model_iter_fn = export_aot_inductor(model, example_inputs)
+            frozen_model_iter_fn = export_aot_inductor(
+                model, example_inputs, args.devices[0]
+            )
         else:
             frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)
 
@@ -1165,7 +1167,7 @@ class AOTInductorModelCache:
     cache = dict()
 
     @classmethod
-    def load(cls, model, example_inputs):
+    def load(cls, model, example_inputs, device):
         key = weakref.ref(model)
         if key not in cls.cache:
             # Register the output dataclass to pytree
@@ -1179,10 +1181,9 @@ def load(cls, model, example_inputs):
 
             module = torch.utils.cpp_extension.load_inline(
                 name="aot_inductor",
-                cpp_sources=[aot_inductor_launcher],
+                cpp_sources=[aot_inductor_launcher(so_path, device)],
                 functions=["run"],
-                extra_ldflags=[so_path],
-                with_cuda=True,
+                with_cuda=(device == "cuda"),
             )
 
             value = {
@@ -1211,8 +1212,8 @@ def opt_export(_, example_inputs):
         return opt_export
 
 
-def export_aot_inductor(model, example_inputs):
-    module, exported = AOTInductorModelCache.load(model, example_inputs)
+def export_aot_inductor(model, example_inputs, device):
+    module, exported = AOTInductorModelCache.load(model, example_inputs, device)
 
     def opt_aot_inductor(_, example_inputs, collect_outputs=False):
         example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
@@ -3596,8 +3597,9 @@ def run(runner, args, original_dir=None):
     elif args.backend or args.export_aot_inductor:
         if args.export_aot_inductor:
             assert not args.training, "AOTInductor only supports inference"
-            assert args.devices == ["cuda"], "AOTInductor only tested for CUDA"
-            optimize_ctx = export_aot_inductor
+            optimize_ctx = functools.partial(
+                export_aot_inductor, device=args.devices[0]
+            )
 
             # AOTInductor doesn't support control flow yet
             runner.skip_models.update(runner.skip_models_due_to_control_flow)
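
Putting the pieces together, a hedged usage sketch of the pattern this diff adopts. so_path and the input tensor are hypothetical placeholders, and aot_inductor_launcher refers to the sketch in the summary above:

    import torch
    import torch.utils.cpp_extension

    device = "cpu"             # "cuda" follows the same path
    so_path = "/tmp/model.so"  # hypothetical: produced by AOTInductor compilation

    # The CUDA toolchain is only required when actually targeting CUDA,
    # which is what enables CPU-only benchmark runs.
    module = torch.utils.cpp_extension.load_inline(
        name="aot_inductor",
        cpp_sources=[aot_inductor_launcher(so_path, device)],
        functions=["run"],
        with_cuda=(device == "cuda"),
    )

    outputs = module.run([torch.randn(8, 16, device=device)])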
