Skip to content

Commit

Permalink
Add torchao to PT2 Benchmark Runner
Browse files — browse the repository at this point in the history
Summary:
X-link: #2268

Support torchao performance and accuracy tests in PT2 Benchmark Runner, using the inductor backend as the baseline.

X-link: pytorch/pytorch#126469

Reviewed By: jerryzh168

Differential Revision: D57463273

Pulled By: xuzhao9

fbshipit-source-id: 64520f18b63107ce5f07447ef7f4a8c841d9ff1f
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed May 20, 2024
1 parent ebbc77b commit b7524a6
Showing 1 changed file with 29 additions and 0 deletions.
userbenchmark/dynamo/dynamobench/common.py — 29 changes: 29 additions & 0 deletions
Original file line number Diff line number Diff line change
Expand Up @@ -3485,6 +3485,18 @@ def get_example_inputs(self):
action="store_true",
help="Measure speedup with TorchInductor",
)
group.add_argument(
"--quantization",
choices=[
"int8dynamic",
"int8weightonly",
"int4weightonly",
"autoquant",
"noquant",
],
default=None,
help="Measure speedup of torchao quantization with TorchInductor baseline",
)
group.add_argument(
"--export",
action="store_true",
Expand Down Expand Up @@ -3679,6 +3691,9 @@ def run(runner, args, original_dir=None):
if args.inductor:
assert args.backend is None
args.backend = "inductor"
if args.quantization:
assert args.backend is None
args.backend = "torchao"
if args.dynamic_batch_only:
args.dynamic_shapes = True
torch._dynamo.config.assume_static_by_default = True
Expand Down Expand Up @@ -3957,6 +3972,20 @@ def run(runner, args, original_dir=None):

# AOTInductor doesn't support control flow yet
runner.skip_models.update(runner.skip_models_due_to_control_flow)
elif args.backend == "torchao":
assert "cuda" in args.devices, "Quantization requires CUDA device."
assert args.bfloat16, "Quantization requires dtype bfloat16."
from .torchao import setup_baseline, torchao_optimize_ctx

setup_baseline()
baseline_ctx = functools.partial(
torch.compile,
backend="inductor",
fullgraph=args.nopython,
mode=args.inductor_compile_mode,
)
runner.model_iter_fn = baseline_ctx(runner.model_iter_fn)
optimize_ctx = torchao_optimize_ctx(args.quantization)
else:
optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython)
experiment = speedup_experiment
Expand Down

0 comments on commit b7524a6

Please sign in to comment.