diff --git a/userbenchmark/dynamo/dynamobench/common.py b/userbenchmark/dynamo/dynamobench/common.py index 6ea7a31a39..f40f403969 100644 --- a/userbenchmark/dynamo/dynamobench/common.py +++ b/userbenchmark/dynamo/dynamobench/common.py @@ -3485,6 +3485,18 @@ def get_example_inputs(self): action="store_true", help="Measure speedup with TorchInductor", ) + group.add_argument( + "--quantization", + choices=[ + "int8dynamic", + "int8weightonly", + "int4weightonly", + "autoquant", + "noquant", + ], + default=None, + help="Measure speedup of torchao quantization with TorchInductor baseline", + ) group.add_argument( "--export", action="store_true", @@ -3679,6 +3691,9 @@ def run(runner, args, original_dir=None): if args.inductor: assert args.backend is None args.backend = "inductor" + if args.quantization: + assert args.backend is None + args.backend = "torchao" if args.dynamic_batch_only: args.dynamic_shapes = True torch._dynamo.config.assume_static_by_default = True @@ -3957,6 +3972,20 @@ def run(runner, args, original_dir=None): # AOTInductor doesn't support control flow yet runner.skip_models.update(runner.skip_models_due_to_control_flow) + elif args.backend == "torchao": + assert "cuda" in args.devices, "Quantization requires CUDA device." + assert args.bfloat16, "Quantization requires dtype bfloat16." + from .torchao import setup_baseline, torchao_optimize_ctx + + setup_baseline() + baseline_ctx = functools.partial( + torch.compile, + backend="inductor", + fullgraph=args.nopython, + mode=args.inductor_compile_mode, + ) + runner.model_iter_fn = baseline_ctx(runner.model_iter_fn) + optimize_ctx = torchao_optimize_ctx(args.quantization) else: optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython) experiment = speedup_experiment