From 0fbdca7ef34bd5823950b4d9c9e959ac55673786 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Mon, 13 May 2024 20:36:03 -0700 Subject: [PATCH] testing autoquant Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- torchbenchmark/util/backends/torchdynamo.py | 20 +++++++++++++++++-- .../group_bench/configs/torch_ao.yaml | 7 +++---- userbenchmark/group_bench/run.py | 20 ++++++++++++------- 3 files changed, 34 insertions(+), 13 deletions(-) diff --git a/torchbenchmark/util/backends/torchdynamo.py b/torchbenchmark/util/backends/torchdynamo.py index 8fa8127d00..e39315990e 100644 --- a/torchbenchmark/util/backends/torchdynamo.py +++ b/torchbenchmark/util/backends/torchdynamo.py @@ -86,7 +86,7 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace: ) parser.add_argument( "--quantization", - choices=["int8dynamic", "int8weightonly", "int4weightonly"], + choices=["int8dynamic", "int8weightonly", "int4weightonly","autoquant"], help="Apply quantization to the model before running it", ) parser.add_argument( @@ -187,21 +187,37 @@ def apply_torchdynamo_args( change_linear_weights_to_int4_woqtensors, change_linear_weights_to_int8_dqtensors, change_linear_weights_to_int8_woqtensors, + ) + torch._dynamo.config.automatic_dynamic_shapes = False torch._dynamo.config.force_parameter_static_shapes = False torch._dynamo.config.cache_size_limit = 1000 assert "cuda" in model.device module, example_inputs = model.get_module() + torch._inductor.config.force_fuse_int_mm_with_mul = True + torch._inductor.config.use_mixed_mm = True + if isinstance(example_inputs, tuple([tuple, list])): + example_inputs = tuple([ + x.to(torch.bfloat16) + if isinstance(x, torch.Tensor) and x.dtype in [torch.float32, torch.float16] + else x + for x in example_inputs + ]) + module=module.to(torch.bfloat16) + with torch.no_grad(): + module(*example_inputs) if args.quantization == "int8dynamic": - torch._inductor.config.force_fuse_int_mm_with_mul = True 
change_linear_weights_to_int8_dqtensors(module) elif args.quantization == "int8weightonly": torch._inductor.config.use_mixed_mm = True change_linear_weights_to_int8_woqtensors(module) elif args.quantization == "int4weightonly": change_linear_weights_to_int4_woqtensors(module) + elif args.quantization == "autoquant": + torchao.autoquant(module, example_inputs, error_on_unseen=False) + if args.freeze_prepack_weights: torch._inductor.config.freezing = True diff --git a/userbenchmark/group_bench/configs/torch_ao.yaml b/userbenchmark/group_bench/configs/torch_ao.yaml index 762668ea3f..78c206d54c 100644 --- a/userbenchmark/group_bench/configs/torch_ao.yaml +++ b/userbenchmark/group_bench/configs/torch_ao.yaml @@ -7,10 +7,9 @@ device: cuda extra_args: --precision bf16 --torchdynamo inductor --inductor-compile-mode max-autotune metrics: - latencies + - gpu_peak_mem test_group: test_batch_size_default: subgroup: - - extra_args: - - extra_args: --quantization int8dynamic - - extra_args: --quantization int8weightonly - - extra_args: --quantization int4weightonly + - extra_args: --quantization noquant + - extra_args: --quantization autoquant diff --git a/userbenchmark/group_bench/run.py b/userbenchmark/group_bench/run.py index 16a7f490f3..4221f4236d 100644 --- a/userbenchmark/group_bench/run.py +++ b/userbenchmark/group_bench/run.py @@ -158,6 +158,7 @@ def models_from_config(config) -> List[str]: basic_models_list = list_models() else: basic_models_list = [config["model"]] + assert isinstance(config["model"], list), "Config model must be a list or string." 
basic_models_list = config["model"] extended_models_list = [] @@ -218,14 +219,19 @@ def run(args: List[str]): results = {} try: for config in group_config.configs: - metrics = get_metrics(group_config.metrics, config) - if "accuracy" in metrics: - metrics_dict = run_config_accuracy(config, metrics, dryrun=args.dryrun) - else: - metrics_dict = run_config(config, metrics, dryrun=args.dryrun) + metrics = get_metrics(group_config.metrics, config) + try: + if "accuracy" in metrics: + metrics_dict = run_config_accuracy(config, metrics, dryrun=args.dryrun) + else: + metrics_dict = run_config(config, metrics, dryrun=args.dryrun) + except KeyboardInterrupt: + raise + except Exception: + metrics_dict = {} config_str = config_to_str(config) - for metric in metrics_dict: - results[f"{config_str}, metric={metric}"] = metrics_dict[metric] + for metric in metrics: + results[f"{config_str}, metric={metric}"] = metrics_dict.get(metric, "err") except KeyboardInterrupt: print("User keyboard interrupted!") result = get_output_json(BM_NAME, results)