From 8fec3dec2c5adcecf12af79c3b7e2b8101515690 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Tue, 16 Jan 2024 10:10:57 -0800
Subject: [PATCH] Enable torchao quantization in framework and group_bench

Summary:
Support torchao quantization code in the framework.

Add a new config `torch_ao.yaml` in the group_bench userbenchmark.

Differential Revision: D52802534
---
 torchbenchmark/util/backends/torchdynamo.py | 34 +++++++++++++++++++
 torchbenchmark/util/extra_args.py           |  4 +--
 torchbenchmark/util/model.py                |  8 ++---
 .../group_bench/configs/torch_ao.yaml       | 13 +++++++
 userbenchmark/group_bench/run.py            | 18 +++++-----
 userbenchmark/utils.py                      | 13 ++++---
 6 files changed, 71 insertions(+), 19 deletions(-)
 create mode 100644 userbenchmark/group_bench/configs/torch_ao.yaml

diff --git a/torchbenchmark/util/backends/torchdynamo.py b/torchbenchmark/util/backends/torchdynamo.py
index c8835e86f6..18c2d9220c 100644
--- a/torchbenchmark/util/backends/torchdynamo.py
+++ b/torchbenchmark/util/backends/torchdynamo.py
@@ -78,6 +78,12 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
         action='store_true',
         help="Enable max autotune gemm"
     )
+    parser.add_argument(
+        "--inductor-compile-mode",
+        default=None,
+        choices=['max-autotune'],
+        help="torch.compile mode argument for inductor runs.",
+    )
     parser.add_argument(
         "--torchinductor_enable_split_cat_fx_pass",
         action='store_true',
@@ -104,6 +110,11 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
         default="false",
         help="Enable triton code dump by setting torch._inductor.config.debug",
     )
+    parser.add_argument(
+        "--quantization",
+        choices=["int8dynamic", "int8weightonly", "int4weightonly"],
+        help="Apply quantization to the model before running it",
+    )
     args, extra_args = parser.parse_known_args(dynamo_args)
     return args, extra_args
 
@@ -123,6 +134,9 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar
 
     if args.torchdynamo == "inductor":
         import torch._inductor as torchinductor
+        if args.inductor_compile_mode == "max-autotune":
+            torchinductor.config.max_autotune = True
+            torchinductor.config.triton.cudagraphs = True
         torchinductor.config.triton.cudagraphs = bool(args.torchinductor_cudagraph)
         if bool(args.torchinductor_post_grad_batch_fusion):
             torchinductor.config.post_grad_fusion_options["batch_linear_post_grad"] = {}
@@ -146,6 +160,26 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar
             torchinductor.config.triton.unique_kernel_names = True
         if args.torchinductor_enable_max_autotune_gemm:
             torchinductor.config.max_autotune_gemm = True
+        if args.quantization:
+            import torchao
+            from torchao.quantization import (
+                change_linear_weights_to_int8_dqtensors,
+                change_linear_weights_to_int8_woqtensors,
+                change_linear_weights_to_int4_woqtensors
+            )
+            torch._dynamo.config.automatic_dynamic_shapes = False
+            torch._dynamo.config.force_parameter_static_shapes = False
+            torch._dynamo.config.cache_size_limit = 1000
+            assert "cuda" in model.device
+            module, example_inputs = model.get_module()
+            if args.quantization == "int8dynamic":
+                torch._inductor.config.force_fuse_int_mm_with_mul = True
+                change_linear_weights_to_int8_dqtensors(module)
+            elif args.quantization == "int8weightonly":
+                torch._inductor.config.use_mixed_mm = True
+                change_linear_weights_to_int8_woqtensors(module)
+            elif args.quantization == "int4weightonly":
+                change_linear_weights_to_int4_woqtensors(module)
         # used for correctness checks, to avoid triton rand() behaving differently from torch rand().
         torchinductor.config.fallback_random = bool(args.torchinductor_fallback_random)
 
diff --git a/torchbenchmark/util/extra_args.py b/torchbenchmark/util/extra_args.py
index d5062ebf05..a7f5f75f3f 100644
--- a/torchbenchmark/util/extra_args.py
+++ b/torchbenchmark/util/extra_args.py
@@ -90,9 +90,9 @@ def apply_decoration_args(model: 'torchbenchmark.util.model.BenchmarkModel', dar
     if dargs.channels_last:
         model.enable_channels_last()
     if dargs.precision == "fp16":
-        model.enable_fp16()
+        model.cast_to_fp16()
     elif dargs.precision == "bf16":
-        model.enable_bf16()
+        model.cast_to_bf16()
     elif dargs.precision == "tf32":
         import torch
         torch.backends.cuda.matmul.allow_tf32 = True
diff --git a/torchbenchmark/util/model.py b/torchbenchmark/util/model.py
index 8593ba4c90..aa8fc6890f 100644
--- a/torchbenchmark/util/model.py
+++ b/torchbenchmark/util/model.py
@@ -352,13 +352,13 @@ def _cast_to(self, cond, action):
         else:
             warnings.warn(UserWarning(f"{model_name} example inputs doesn't cast to {action} yet!"))
 
-    def enable_bf16(self):
-        tensor_cond = lambda x: x.dtype == torch.float32
+    def cast_to_bf16(self):
+        tensor_cond = lambda x: x.is_floating_point()
         tensor_action = lambda x: x.to(torch.bfloat16)
         self._cast_to(tensor_cond, tensor_action)
 
-    def enable_fp16(self):
-        tensor_cond = lambda x: x.dtype == torch.float32
+    def cast_to_fp16(self):
+        tensor_cond = lambda x: x.is_floating_point()
         tensor_action = lambda x: x.half()
         self._cast_to(tensor_cond, tensor_action)
 
diff --git a/userbenchmark/group_bench/configs/torch_ao.yaml b/userbenchmark/group_bench/configs/torch_ao.yaml
new file mode 100644
index 0000000000..3d0354c37b
--- /dev/null
+++ b/userbenchmark/group_bench/configs/torch_ao.yaml
@@ -0,0 +1,13 @@
+model: resnet50
+test: eval
+device: cuda
+extra_args: --precision bf16 --torchdynamo inductor --inductor-compile-mode max-autotune
+metrics:
+  - latencies
+test_group:
+  test_batch_size_default:
+    subgroup:
+      - extra_args:
+      - extra_args: --quantization int8dynamic
+      - extra_args: --quantization int8weightonly
+      - extra_args: --quantization int4weightonly
diff --git a/userbenchmark/group_bench/run.py b/userbenchmark/group_bench/run.py
index 9df03b1452..5f67fa44e8 100644
--- a/userbenchmark/group_bench/run.py
+++ b/userbenchmark/group_bench/run.py
@@ -37,7 +37,7 @@ def configs(self):
 
 def config_to_str(config: TorchBenchModelConfig) -> str:
     metrics_base = f"model={config.name}, test={config.test}, device={config.device}," + \
-                   f" bs={config.batch_size}, extra_args={config.extra_args}"
+                   f" bs={config.batch_size}, extra_args={' '.join(config.extra_args)}"
     return metrics_base
 
 def str_to_config(metric_name: str) -> TorchBenchModelConfig:
@@ -80,10 +80,10 @@ def init_output_dir(configs: List[TorchBenchModelConfig], output_dir: pathlib.Pa
     return result
 
 def get_metrics(metrics: List[str], config: TorchBenchModelConfig) -> List[str]:
-    if metrics:
-        return metrics
     if "--accuracy" in config.extra_args:
         return ["accuracy"]
+    if metrics:
+        return metrics
     return ["latencies", "cpu_peak_mem", "gpu_peak_mem"]
 
 def validate(candidates: List[str], choices: List[str]) -> List[str]:
@@ -166,12 +166,14 @@ def load_group_config(config_file: str) -> TorchBenchGroupBenchConfig:
     group_configs = {}
     for group_name in data["test_group"]:
         group_configs[group_name] = []
-        group_extra_args = data["test_group"][group_name]["extra_args"].split(" ")
-        subgroup_extra_args_list = list(map(lambda x: x["extra_args"].split(" "), data["test_group"][group_name]["subgroup"]))
-        for subgroup_extra_args in subgroup_extra_args_list:
+        group_extra_args = list(filter(lambda x: bool(x), data["test_group"][group_name].get("extra_args", "").split(" ")))
+        for subgroup in data["test_group"][group_name]["subgroup"]:
+            subgroup_extra_args = subgroup.get("extra_args", "")
+            subgroup_extra_args = "" if subgroup_extra_args == None else subgroup_extra_args
+            subgroup_extra_args_list = list(filter(lambda x: bool(x), subgroup_extra_args.split(" ")))
             subgroup_config = copy.deepcopy(baseline_config)
             subgroup_config.extra_args.extend(group_extra_args)
-            subgroup_config.extra_args.extend(subgroup_extra_args)
+            subgroup_config.extra_args.extend(subgroup_extra_args_list)
             group_configs[group_name].append(subgroup_config)
     return TorchBenchGroupBenchConfig(baseline_config, metrics, group_configs)
 
@@ -180,7 +182,7 @@ def parse_args(args: List[str]):
     parser.add_argument("--config", "-c", required=True, help="YAML config to specify group of tests to run.")
    parser.add_argument("--dryrun", action="store_true", help="Dryrun the command.")
     parser.add_argument("--debug", action="store_true", help="Save the debug output.")
-    parser.add_argument("--output", default=f"/tmp/{BM_NAME}", help="Output torchbench userbenchmark metrics file path.")
+    parser.add_argument("--output", default=get_default_output_json_path(BM_NAME), help="Output torchbench userbenchmark metrics file path.")
     return parser.parse_args(args)
 
 def run(args: List[str]):
diff --git a/userbenchmark/utils.py b/userbenchmark/utils.py
index 6d01d60345..37ad835a48 100644
--- a/userbenchmark/utils.py
+++ b/userbenchmark/utils.py
@@ -93,13 +93,16 @@ def get_output_json(bm_name, metrics) -> Dict[str, Any]:
     }
 
 
-def get_output_dir(bm_name) -> Path:
-    current_dir = Path(os.path.dirname(os.path.abspath(__file__)))
-    target_dir = current_dir.parent.joinpath(USERBENCHMARK_OUTPUT_PREFIX, bm_name)
-    target_dir.mkdir(exist_ok=True, parents=True)
+def get_output_dir(bm_name: str) -> Path:
+    import torch
+    IS_FBCODE = False if hasattr(torch.version, "git_version") else True
+    if not IS_FBCODE:
+        current_dir = Path(os.path.dirname(os.path.abspath(__file__)))
+        target_dir = current_dir.parent.joinpath(USERBENCHMARK_OUTPUT_PREFIX, bm_name)
+    else:
+        target_dir = Path(f"/tmp/{bm_name}")
     return target_dir
 
-
 def get_default_output_json_path(bm_name: str, target_dir: Path=None) -> str:
     if target_dir is None:
         target_dir = get_output_dir(bm_name)
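
Note (not part of the patch): a minimal standalone sketch of the int8 dynamic path that apply_torchdynamo_args() now wires up, useful as a quick sanity check outside the harness. The toy module, shapes, and dtype below are placeholders; only the torchao helper and the dynamo/inductor config knobs come from the change above.

    # Hypothetical sketch: exercise the same torchao int8-dynamic path the benchmark uses.
    import torch
    import torch._dynamo
    import torch._inductor as torchinductor
    import torch.nn as nn
    from torchao.quantization import change_linear_weights_to_int8_dqtensors

    # Placeholder module and input; the benchmark instead pulls these from model.get_module().
    module = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 8)).cuda()
    example_input = torch.randn(16, 1024, device="cuda")

    # Mirror the settings applied for --quantization int8dynamic in the patch.
    torch._dynamo.config.automatic_dynamic_shapes = False
    torch._dynamo.config.force_parameter_static_shapes = False
    torch._dynamo.config.cache_size_limit = 1000
    torchinductor.config.force_fuse_int_mm_with_mul = True

    change_linear_weights_to_int8_dqtensors(module)  # swaps nn.Linear weights in place
    compiled = torch.compile(module, mode="max-autotune")
    with torch.no_grad():
        out = compiled(example_input)

The full int8dynamic/int8weightonly/int4weightonly sweep is intended to run through the group_bench userbenchmark with the new torch_ao.yaml config, driven by the repository's usual userbenchmark entry point.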