From 58a632331d3dd102d0ffa27184b428b7a8bf6172 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Thu, 16 May 2024 15:58:31 -0700
Subject: [PATCH] Add torchao to PT2 Benchmark Runner (#2268)

Summary:
X-link: https://github.com/pytorch/pytorch/pull/126469

Support torchao performance and accuracy tests in PT2 Benchmark Runner, using the inductor backend as the baseline.

Reviewed By: jerryzh168

Differential Revision: D57463273
---
 userbenchmark/dynamo/dynamobench/common.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/userbenchmark/dynamo/dynamobench/common.py b/userbenchmark/dynamo/dynamobench/common.py
index 096dbc48ec..358a2a8ded 100644
--- a/userbenchmark/dynamo/dynamobench/common.py
+++ b/userbenchmark/dynamo/dynamobench/common.py
@@ -3467,6 +3467,12 @@ def get_example_inputs(self):
         action="store_true",
         help="Measure speedup with TorchInductor",
     )
+    group.add_argument(
+        "--quantization",
+        choices=["int8dynamic", "int8weightonly", "int4weightonly", "autoquant", "noquant"],
+        default=None,
+        help="Measure speedup of torchao quantization with TorchInductor baseline",
+    )
     group.add_argument(
         "--export",
         action="store_true",
@@ -3661,6 +3667,9 @@ def run(runner, args, original_dir=None):
     if args.inductor:
         assert args.backend is None
         args.backend = "inductor"
+    if args.quantization:
+        assert args.backend is None
+        args.backend = "torchao"
     if args.dynamic_batch_only:
         args.dynamic_shapes = True
         torch._dynamo.config.assume_static_by_default = True
@@ -3939,6 +3948,18 @@ def run(runner, args, original_dir=None):
 
             # AOTInductor doesn't support control flow yet
             runner.skip_models.update(runner.skip_models_due_to_control_flow)
+        elif args.backend == "torchao":
+            assert "cuda" in args.devices, "Quantization requires CUDA device."
+            assert args.bfloat16, "Quantization requires dtype bfloat16."
+            from .torchao import torchao_optimize_ctx
+            baseline_ctx = functools.partial(
+                torch.compile,
+                backend="inductor",
+                fullgraph=args.nopython,
+                mode=args.inductor_compile_mode,
+            )
+            runner.model_iter_fn = baseline_ctx(runner.model_iter_fn)
+            optimize_ctx = torchao_optimize_ctx(args.quantization)
         else:
             optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython)
         experiment = speedup_experiment
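
For context, the `--quantization` path imports `torchao_optimize_ctx` from a `.torchao` helper module that is not part of this diff. The sketch below shows, under stated assumptions, what such a factory could look like: it returns a decorator for the runner's `model_iter_fn` that quantizes each model once with torchao and then keeps calling the original (inductor-compiled) iteration function. The wrapper name and the per-model cache are hypothetical, and the torchao calls follow torchao's public `quantize_`/`autoquant` API, which may differ from the actual helper shipped with the benchmark.

import torch
import torchao
from torchao.quantization import (
    int4_weight_only,
    int8_dynamic_activation_int8_weight,
    int8_weight_only,
    quantize_,
)


def torchao_optimize_ctx(quantization: str):
    """Illustrative sketch of the factory used as `optimize_ctx` above."""
    quantized = {}  # id(module) -> converted module, so each model is quantized once

    def inner(model_iter_fn):
        def _torchao_iter_fn(module, example_inputs, *args, **kwargs):
            key = id(module)
            if key not in quantized:
                if quantization == "int8dynamic":
                    quantize_(module, int8_dynamic_activation_int8_weight())
                elif quantization == "int8weightonly":
                    quantize_(module, int8_weight_only())
                elif quantization == "int4weightonly":
                    quantize_(module, int4_weight_only())
                elif quantization == "autoquant":
                    module = torchao.autoquant(
                        torch.compile(module, mode="max-autotune")
                    )
                # "noquant" leaves the module untouched, so the run measures the
                # inductor-compiled baseline on its own.
                quantized[key] = module
            return model_iter_fn(quantized[key], example_inputs, *args, **kwargs)

        return _torchao_iter_fn

    return inner

The runner would then apply the returned callable the same way it applies other backends' optimize_ctx, roughly `optimized_iter_fn = optimize_ctx(runner.model_iter_fn)`, while the CUDA and bfloat16 asserts in the hunk above enforce the preconditions the quantized kernels expect.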