From 4cda19903b848133fbe056120944339ca28eca53 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 23 Jan 2024 14:51:39 -0800 Subject: [PATCH] Enable torchao quantization in framework and group_bench (#2116) Summary: Support torchao quantization code in the framework. Add a new config `torch_ao.yaml` in the group_bench userbenchmark. Differential Revision: D52802534 --- torchbenchmark/util/backends/torchdynamo.py | 34 +++++++++++ .../group_bench/configs/torch_ao.yaml | 13 +++++ userbenchmark/group_bench/run.py | 57 ++++++++++++------- userbenchmark/utils.py | 12 ++-- 4 files changed, 91 insertions(+), 25 deletions(-) create mode 100644 userbenchmark/group_bench/configs/torch_ao.yaml diff --git a/torchbenchmark/util/backends/torchdynamo.py b/torchbenchmark/util/backends/torchdynamo.py index c8835e86f6..18c2d9220c 100644 --- a/torchbenchmark/util/backends/torchdynamo.py +++ b/torchbenchmark/util/backends/torchdynamo.py @@ -78,6 +78,12 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace: action='store_true', help="Enable max autotune gemm" ) + parser.add_argument( + "--inductor-compile-mode", + default=None, + choices=['max-autotune'], + help="torch.compile mode argument for inductor runs.", + ) parser.add_argument( "--torchinductor_enable_split_cat_fx_pass", action='store_true', @@ -104,6 +110,11 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace: default="false", help="Enable triton code dump by setting torch._inductor.config.debug", ) + parser.add_argument( + "--quantization", + choices=["int8dynamic", "int8weightonly", "int4weightonly"], + help="Apply quantization to the model before running it", + ) args, extra_args = parser.parse_known_args(dynamo_args) return args, extra_args @@ -123,6 +134,9 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar if args.torchdynamo == "inductor": import torch._inductor as torchinductor + if args.inductor_compile_mode == "max-autotune": + torchinductor.config.max_autotune = True + torchinductor.config.triton.cudagraphs = True torchinductor.config.triton.cudagraphs = bool(args.torchinductor_cudagraph) if bool(args.torchinductor_post_grad_batch_fusion): torchinductor.config.post_grad_fusion_options["batch_linear_post_grad"] = {} @@ -146,6 +160,26 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar torchinductor.config.triton.unique_kernel_names = True if args.torchinductor_enable_max_autotune_gemm: torchinductor.config.max_autotune_gemm = True + if args.quantization: + import torchao + from torchao.quantization import ( + change_linear_weights_to_int8_dqtensors, + change_linear_weights_to_int8_woqtensors, + change_linear_weights_to_int4_woqtensors + ) + torch._dynamo.config.automatic_dynamic_shapes = False + torch._dynamo.config.force_parameter_static_shapes = False + torch._dynamo.config.cache_size_limit = 1000 + assert "cuda" in model.device + module, example_inputs = model.get_module() + if args.quantization == "int8dynamic": + torch._inductor.config.force_fuse_int_mm_with_mul = True + change_linear_weights_to_int8_dqtensors(module) + elif args.quantization == "int8weightonly": + torch._inductor.config.use_mixed_mm = True + change_linear_weights_to_int8_woqtensors(module) + elif args.quantization == "int4weightonly": + change_linear_weights_to_int4_woqtensors(module) # used for correctness checks, to avoid triton rand() behaving differently from torch rand(). torchinductor.config.fallback_random = bool(args.torchinductor_fallback_random) diff --git a/userbenchmark/group_bench/configs/torch_ao.yaml b/userbenchmark/group_bench/configs/torch_ao.yaml new file mode 100644 index 0000000000..77cbec12e8 --- /dev/null +++ b/userbenchmark/group_bench/configs/torch_ao.yaml @@ -0,0 +1,13 @@ +model: "*" +test: eval +device: cuda +extra_args: --precision bf16 --torchdynamo inductor --inductor-compile-mode max-autotune +metrics: + - latencies +test_group: + test_batch_size_default: + subgroup: + - extra_args: + - extra_args: --quantization int8dynamic + - extra_args: --quantization int8weightonly + - extra_args: --quantization int4weightonly diff --git a/userbenchmark/group_bench/run.py b/userbenchmark/group_bench/run.py index 9df03b1452..f490d13036 100644 --- a/userbenchmark/group_bench/run.py +++ b/userbenchmark/group_bench/run.py @@ -27,7 +27,7 @@ @dataclasses.dataclass class TorchBenchGroupBenchConfig: - baseline_config: TorchBenchModelConfig + baseline_configs: List[TorchBenchModelConfig] metrics: List[str] group_configs: Dict[str, List[TorchBenchModelConfig]] @@ -37,7 +37,7 @@ def configs(self): def config_to_str(config: TorchBenchModelConfig) -> str: metrics_base = f"model={config.name}, test={config.test}, device={config.device}," + \ - f" bs={config.batch_size}, extra_args={config.extra_args}" + f" bs={config.batch_size}, extra_args={' '.join(config.extra_args)}" return metrics_base def str_to_config(metric_name: str) -> TorchBenchModelConfig: @@ -80,10 +80,10 @@ def init_output_dir(configs: List[TorchBenchModelConfig], output_dir: pathlib.Pa return result def get_metrics(metrics: List[str], config: TorchBenchModelConfig) -> List[str]: - if metrics: - return metrics if "--accuracy" in config.extra_args: return ["accuracy"] + if metrics: + return metrics return ["latencies", "cpu_peak_mem", "gpu_peak_mem"] def validate(candidates: List[str], choices: List[str]) -> List[str]: @@ -150,37 +150,52 @@ def run_config_accuracy(config: TorchBenchModelConfig, metrics: List[str], dryru print(" [oserror]", flush=True) return {"accuracy": str(e)} +def models_from_config(config) -> List[str]: + assert "model" in config, f"We expect users to define models in config file." + if isinstance(config["model"], str): + if config["model"] == "*": + return list_models() + else: + return [config["model"]] + assert isinstance(config["model", list]), "Config model must be a list or string." + return config["model"] + def load_group_config(config_file: str) -> TorchBenchGroupBenchConfig: if not os.path.exists(config_file): config_file = os.path.join(DEFAULT_CONFIG_DIR, config_file) with open(config_file, "r") as fp: data = yaml.safe_load(fp) - baseline_config = TorchBenchModelConfig( - name=data["model"], - test=data["test"], - device=data["device"], - batch_size=data["batch_size"] if "batch_size" in data else None, - extra_args=data["extra_args"].split(" ") if "extra_args" in data else [], - ) + baseline_configs = [ + TorchBenchModelConfig( + name=model, + test=data["test"], + device=data["device"], + batch_size=data["batch_size"] if "batch_size" in data else None, + extra_args=data["extra_args"].split(" ") if "extra_args" in data else [], + ) for model in models_from_config(data) + ] metrics = data["metrics"] if "metrics" in data else [] group_configs = {} for group_name in data["test_group"]: - group_configs[group_name] = [] - group_extra_args = data["test_group"][group_name]["extra_args"].split(" ") - subgroup_extra_args_list = list(map(lambda x: x["extra_args"].split(" "), data["test_group"][group_name]["subgroup"])) - for subgroup_extra_args in subgroup_extra_args_list: - subgroup_config = copy.deepcopy(baseline_config) - subgroup_config.extra_args.extend(group_extra_args) - subgroup_config.extra_args.extend(subgroup_extra_args) - group_configs[group_name].append(subgroup_config) - return TorchBenchGroupBenchConfig(baseline_config, metrics, group_configs) + group_configs[group_name] = [] + group_extra_args = list(filter(lambda x: bool(x), data["test_group"][group_name].get("extra_args", "").split(" "))) + for subgroup in data["test_group"][group_name]["subgroup"]: + subgroup_extra_args = subgroup.get("extra_args", "") + subgroup_extra_args = "" if subgroup_extra_args == None else subgroup_extra_args + subgroup_extra_args_list = list(filter(lambda x: bool(x), subgroup_extra_args.split(" "))) + for baseline_config in baseline_configs: + subgroup_config = copy.deepcopy(baseline_config) + subgroup_config.extra_args.extend(group_extra_args) + subgroup_config.extra_args.extend(subgroup_extra_args_list) + group_configs[group_name].append(subgroup_config) + return TorchBenchGroupBenchConfig(baseline_configs, metrics, group_configs) def parse_args(args: List[str]): parser = argparse.ArgumentParser() parser.add_argument("--config", "-c", required=True, help="YAML config to specify group of tests to run.") parser.add_argument("--dryrun", action="store_true", help="Dryrun the command.") parser.add_argument("--debug", action="store_true", help="Save the debug output.") - parser.add_argument("--output", default=f"/tmp/{BM_NAME}", help="Output torchbench userbenchmark metrics file path.") + parser.add_argument("--output", default=get_default_output_json_path(BM_NAME), help="Output torchbench userbenchmark metrics file path.") return parser.parse_args(args) def run(args: List[str]): diff --git a/userbenchmark/utils.py b/userbenchmark/utils.py index 6d01d60345..6ae9813715 100644 --- a/userbenchmark/utils.py +++ b/userbenchmark/utils.py @@ -93,10 +93,14 @@ def get_output_json(bm_name, metrics) -> Dict[str, Any]: } -def get_output_dir(bm_name) -> Path: - current_dir = Path(os.path.dirname(os.path.abspath(__file__))) - target_dir = current_dir.parent.joinpath(USERBENCHMARK_OUTPUT_PREFIX, bm_name) - target_dir.mkdir(exist_ok=True, parents=True) +def get_output_dir(bm_name: str) -> Path: + import torch + IS_FBCODE = False if hasattr(torch.version, "git_version") else True + if not IS_FBCODE: + current_dir = Path(os.path.dirname(os.path.abspath(__file__))) + target_dir = current_dir.parent.joinpath(USERBENCHMARK_OUTPUT_PREFIX, bm_name) + else: + target_dir = Path(f"/tmp/{bm_name}") return target_dir