Enable torchao quantization in framework and group_bench
Summary:
Add support for torchao quantization in the benchmark framework.

Add a new config `torch_ao.yaml` in the group_bench userbenchmark.

Differential Revision: D52802534
xuzhao9 authored and facebook-github-bot committed Jan 16, 2024
1 parent 6a8b941 commit 8fec3de
Showing 6 changed files with 71 additions and 19 deletions.
34 changes: 34 additions & 0 deletions torchbenchmark/util/backends/torchdynamo.py
@@ -78,6 +78,12 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
         action='store_true',
         help="Enable max autotune gemm"
     )
+    parser.add_argument(
+        "--inductor-compile-mode",
+        default=None,
+        choices=['max-autotune'],
+        help="torch.compile mode argument for inductor runs.",
+    )
     parser.add_argument(
         "--torchinductor_enable_split_cat_fx_pass",
         action='store_true',
@@ -104,6 +110,11 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
         default="false",
         help="Enable triton code dump by setting torch._inductor.config.debug",
     )
+    parser.add_argument(
+        "--quantization",
+        choices=["int8dynamic", "int8weightonly", "int4weightonly"],
+        help="Apply quantization to the model before running it",
+    )
     args, extra_args = parser.parse_known_args(dynamo_args)
     return args, extra_args

@@ -123,6 +134,9 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar

     if args.torchdynamo == "inductor":
         import torch._inductor as torchinductor
+        if args.inductor_compile_mode == "max-autotune":
+            torchinductor.config.max_autotune = True
+            torchinductor.config.triton.cudagraphs = True
         torchinductor.config.triton.cudagraphs = bool(args.torchinductor_cudagraph)
         if bool(args.torchinductor_post_grad_batch_fusion):
             torchinductor.config.post_grad_fusion_options["batch_linear_post_grad"] = {}
@@ -146,6 +160,26 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar
             torchinductor.config.triton.unique_kernel_names = True
         if args.torchinductor_enable_max_autotune_gemm:
             torchinductor.config.max_autotune_gemm = True
+        if args.quantization:
+            import torchao
+            from torchao.quantization import (
+                change_linear_weights_to_int8_dqtensors,
+                change_linear_weights_to_int8_woqtensors,
+                change_linear_weights_to_int4_woqtensors
+            )
+            torch._dynamo.config.automatic_dynamic_shapes = False
+            torch._dynamo.config.force_parameter_static_shapes = False
+            torch._dynamo.config.cache_size_limit = 1000
+            assert "cuda" in model.device
+            module, example_inputs = model.get_module()
+            if args.quantization == "int8dynamic":
+                torch._inductor.config.force_fuse_int_mm_with_mul = True
+                change_linear_weights_to_int8_dqtensors(module)
+            elif args.quantization == "int8weightonly":
+                torch._inductor.config.use_mixed_mm = True
+                change_linear_weights_to_int8_woqtensors(module)
+            elif args.quantization == "int4weightonly":
+                change_linear_weights_to_int4_woqtensors(module)

         # used for correctness checks, to avoid triton rand() behaving differently from torch rand().
         torchinductor.config.fallback_random = bool(args.torchinductor_fallback_random)
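For reference, a minimal standalone sketch of what the int8dynamic path boils down to. The toy model, shapes, and compile mode are illustrative; the torchao calls are the same legacy APIs the diff imports, and a CUDA device is assumed (as the assert above requires):

import torch
import torch._inductor as torchinductor
from torchao.quantization import change_linear_weights_to_int8_dqtensors

# Hypothetical stand-in for a benchmark model's module and example inputs.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
).cuda()
example_input = torch.randn(16, 1024, device="cuda")

# Swap each nn.Linear weight for an int8 dynamically-quantized tensor subclass,
# then ask inductor to fuse the int8 matmul with the rescaling multiply.
torchinductor.config.force_fuse_int_mm_with_mul = True
change_linear_weights_to_int8_dqtensors(model)
compiled = torch.compile(model, mode="max-autotune")
out = compiled(example_input)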
4 changes: 2 additions & 2 deletions torchbenchmark/util/extra_args.py
@@ -90,9 +90,9 @@ def apply_decoration_args(model: 'torchbenchmark.util.model.BenchmarkModel', dar
     if dargs.channels_last:
         model.enable_channels_last()
     if dargs.precision == "fp16":
-        model.enable_fp16()
+        model.cast_to_fp16()
     elif dargs.precision == "bf16":
-        model.enable_bf16()
+        model.cast_to_bf16()
     elif dargs.precision == "tf32":
         import torch
         torch.backends.cuda.matmul.allow_tf32 = True
8 changes: 4 additions & 4 deletions torchbenchmark/util/model.py
@@ -352,13 +352,13 @@ def _cast_to(self, cond, action):
         else:
             warnings.warn(UserWarning(f"{model_name} example inputs doesn't cast to {action} yet!"))

-    def enable_bf16(self):
-        tensor_cond = lambda x: x.dtype == torch.float32
+    def cast_to_bf16(self):
+        tensor_cond = lambda x: x.is_floating_point()
         tensor_action = lambda x: x.to(torch.bfloat16)
         self._cast_to(tensor_cond, tensor_action)

-    def enable_fp16(self):
-        tensor_cond = lambda x: x.dtype == torch.float32
+    def cast_to_fp16(self):
+        tensor_cond = lambda x: x.is_floating_point()
         tensor_action = lambda x: x.half()
         self._cast_to(tensor_cond, tensor_action)
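The renamed helpers also loosen the cast condition. A quick illustration of the difference, using toy tensors rather than framework code:

import torch

old_cond = lambda x: x.dtype == torch.float32   # previous check
new_cond = lambda x: x.is_floating_point()      # new check

t64 = torch.randn(2, 2, dtype=torch.float64)
assert not old_cond(t64)  # fp64 inputs were silently left uncast before
assert new_cond(t64)      # now any floating dtype is cast to fp16/bf16
assert not new_cond(torch.ones(2, dtype=torch.int64))  # integer tensors still skipped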
13 changes: 13 additions & 0 deletions userbenchmark/group_bench/configs/torch_ao.yaml
@@ -0,0 +1,13 @@
+model: resnet50
+test: eval
+device: cuda
+extra_args: --precision bf16 --torchdynamo inductor --inductor-compile-mode max-autotune
+metrics:
+  - latencies
+test_group:
+  test_batch_size_default:
+    subgroup:
+      - extra_args:
+      - extra_args: --quantization int8dynamic
+      - extra_args: --quantization int8weightonly
+      - extra_args: --quantization int4weightonly
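Assuming torchbench's usual userbenchmark entry point, this config should be runnable with something like:

python run_benchmark.py group_bench --config userbenchmark/group_bench/configs/torch_ao.yaml

which benchmarks resnet50 eval in bf16 under inductor max-autotune once as the unquantized baseline subgroup and once per quantization mode.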
18 changes: 10 additions & 8 deletions userbenchmark/group_bench/run.py
@@ -37,7 +37,7 @@ def configs(self):

 def config_to_str(config: TorchBenchModelConfig) -> str:
     metrics_base = f"model={config.name}, test={config.test}, device={config.device}," + \
-                   f" bs={config.batch_size}, extra_args={config.extra_args}"
+                   f" bs={config.batch_size}, extra_args={' '.join(config.extra_args)}"
     return metrics_base

 def str_to_config(metric_name: str) -> TorchBenchModelConfig:
@@ -80,10 +80,10 @@ def init_output_dir(configs: List[TorchBenchModelConfig], output_dir: pathlib.Pa
     return result

 def get_metrics(metrics: List[str], config: TorchBenchModelConfig) -> List[str]:
-    if metrics:
-        return metrics
     if "--accuracy" in config.extra_args:
         return ["accuracy"]
+    if metrics:
+        return metrics
     return ["latencies", "cpu_peak_mem", "gpu_peak_mem"]

 def validate(candidates: List[str], choices: List[str]) -> List[str]:
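The reordering in get_metrics changes precedence: an --accuracy run now reports accuracy even when an explicit metrics list is present. A toy check, with plain lists standing in for the config:

def get_metrics(metrics, extra_args):
    if "--accuracy" in extra_args:
        return ["accuracy"]
    if metrics:
        return metrics
    return ["latencies", "cpu_peak_mem", "gpu_peak_mem"]

print(get_metrics(["latencies"], ["--accuracy"]))  # ['accuracy'] (was ['latencies'] before the swap)
print(get_metrics(["latencies"], []))              # ['latencies']
print(get_metrics([], []))                         # the default metric set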
@@ -166,12 +166,14 @@ def load_group_config(config_file: str) -> TorchBenchGroupBenchConfig:
     group_configs = {}
     for group_name in data["test_group"]:
         group_configs[group_name] = []
-        group_extra_args = data["test_group"][group_name]["extra_args"].split(" ")
-        subgroup_extra_args_list = list(map(lambda x: x["extra_args"].split(" "), data["test_group"][group_name]["subgroup"]))
-        for subgroup_extra_args in subgroup_extra_args_list:
+        group_extra_args = list(filter(lambda x: bool(x), data["test_group"][group_name].get("extra_args", "").split(" ")))
+        for subgroup in data["test_group"][group_name]["subgroup"]:
+            subgroup_extra_args = subgroup.get("extra_args", "")
+            subgroup_extra_args = "" if subgroup_extra_args == None else subgroup_extra_args
+            subgroup_extra_args_list = list(filter(lambda x: bool(x), subgroup_extra_args.split(" ")))
             subgroup_config = copy.deepcopy(baseline_config)
             subgroup_config.extra_args.extend(group_extra_args)
-            subgroup_config.extra_args.extend(subgroup_extra_args)
+            subgroup_config.extra_args.extend(subgroup_extra_args_list)
             group_configs[group_name].append(subgroup_config)
     return TorchBenchGroupBenchConfig(baseline_config, metrics, group_configs)
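To see what the rewritten loop produces, here is a hand-expansion for the torch_ao.yaml group above, with plain dicts standing in for the parsed YAML (no file I/O or TorchBenchModelConfig involved):

# Parsed "test_batch_size_default" group; a bare "- extra_args:" loads as None,
# which is exactly the case the new None-guard handles.
group = {
    "subgroup": [
        {"extra_args": None},
        {"extra_args": "--quantization int8dynamic"},
        {"extra_args": "--quantization int8weightonly"},
        {"extra_args": "--quantization int4weightonly"},
    ],
}
group_extra_args = list(filter(bool, group.get("extra_args", "").split(" ")))
for subgroup in group["subgroup"]:
    raw = subgroup.get("extra_args", "") or ""  # None-safe, like the patch
    print(group_extra_args + list(filter(bool, raw.split(" "))))
# -> [], then ['--quantization', 'int8dynamic'], etc.: one extra_args list per subgroup,
#    appended to a deep copy of the baseline config.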

@@ -180,7 +182,7 @@ def parse_args(args: List[str]):
     parser.add_argument("--config", "-c", required=True, help="YAML config to specify group of tests to run.")
     parser.add_argument("--dryrun", action="store_true", help="Dryrun the command.")
     parser.add_argument("--debug", action="store_true", help="Save the debug output.")
-    parser.add_argument("--output", default=f"/tmp/{BM_NAME}", help="Output torchbench userbenchmark metrics file path.")
+    parser.add_argument("--output", default=get_default_output_json_path(BM_NAME), help="Output torchbench userbenchmark metrics file path.")
     return parser.parse_args(args)

 def run(args: List[str]):
13 changes: 8 additions & 5 deletions userbenchmark/utils.py
@@ -93,13 +93,16 @@ def get_output_json(bm_name, metrics) -> Dict[str, Any]:
     }


-def get_output_dir(bm_name) -> Path:
-    current_dir = Path(os.path.dirname(os.path.abspath(__file__)))
-    target_dir = current_dir.parent.joinpath(USERBENCHMARK_OUTPUT_PREFIX, bm_name)
-    target_dir.mkdir(exist_ok=True, parents=True)
+def get_output_dir(bm_name: str) -> Path:
+    import torch
+    IS_FBCODE = False if hasattr(torch.version, "git_version") else True
+    if not IS_FBCODE:
+        current_dir = Path(os.path.dirname(os.path.abspath(__file__)))
+        target_dir = current_dir.parent.joinpath(USERBENCHMARK_OUTPUT_PREFIX, bm_name)
+    else:
+        target_dir = Path(f"/tmp/{bm_name}")
     return target_dir


 def get_default_output_json_path(bm_name: str, target_dir: Path=None) -> str:
     if target_dir is None:
         target_dir = get_output_dir(bm_name)
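The fbcode check above relies on OSS PyTorch builds exposing torch.version.git_version while internal builds do not (an assumption inferred from this diff). A one-liner to see which side a given build lands on, illustrative only:

import torch

# Mirrors the detection in get_output_dir; prints which output dir would be used.
IS_FBCODE = not hasattr(torch.version, "git_version")
print("internal build -> /tmp/<bm_name>" if IS_FBCODE else "OSS build -> repo-local userbenchmark output dir")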
