Enable torchao quantization in framework and group_bench (#2116)
Summary:

Support torchao quantization in the framework by adding a `--quantization` flag (int8dynamic, int8weightonly, int4weightonly) to the torchdynamo/inductor backend.

Add a new config `torch_ao.yaml` to the group_bench userbenchmark that sweeps these quantization modes against an unquantized baseline.

Differential Revision: D52802534
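
With this change the sweep can be launched by pointing the group_bench userbenchmark at the new config, e.g. `python run_benchmark.py group_bench --config torch_ao.yaml` (entry-point name assumed from the existing userbenchmark harness); each selected model is then run once per subgroup, i.e. once unquantized and once per torchao quantization mode.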
xuzhao9 authored and facebook-github-bot committed Jan 23, 2024
1 parent c1148b9 commit b5d3f0b
Showing 4 changed files with 91 additions and 25 deletions.
34 changes: 34 additions & 0 deletions torchbenchmark/util/backends/torchdynamo.py
@@ -78,6 +78,12 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
action='store_true',
help="Enable max autotune gemm"
)
parser.add_argument(
"--inductor-compile-mode",
default=None,
choices=['max-autotune'],
help="torch.compile mode argument for inductor runs.",
)
parser.add_argument(
"--torchinductor_enable_split_cat_fx_pass",
action='store_true',
@@ -104,6 +110,11 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
default="false",
help="Enable triton code dump by setting torch._inductor.config.debug",
)
parser.add_argument(
"--quantization",
choices=["int8dynamic", "int8weightonly", "int4weightonly"],
help="Apply quantization to the model before running it",
)
args, extra_args = parser.parse_known_args(dynamo_args)
return args, extra_args

@@ -123,6 +134,9 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar

if args.torchdynamo == "inductor":
import torch._inductor as torchinductor
if args.inductor_compile_mode == "max-autotune":
torchinductor.config.max_autotune = True
torchinductor.config.triton.cudagraphs = True
torchinductor.config.triton.cudagraphs = bool(args.torchinductor_cudagraph)
if bool(args.torchinductor_post_grad_batch_fusion):
torchinductor.config.post_grad_fusion_options["batch_linear_post_grad"] = {}
@@ -146,6 +160,26 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar
torchinductor.config.triton.unique_kernel_names = True
if args.torchinductor_enable_max_autotune_gemm:
torchinductor.config.max_autotune_gemm = True
if args.quantization:
import torchao
from torchao.quantization import (
change_linear_weights_to_int8_dqtensors,
change_linear_weights_to_int8_woqtensors,
change_linear_weights_to_int4_woqtensors
)
torch._dynamo.config.automatic_dynamic_shapes = False
torch._dynamo.config.force_parameter_static_shapes = False
torch._dynamo.config.cache_size_limit = 1000
assert "cuda" in model.device
module, example_inputs = model.get_module()
if args.quantization == "int8dynamic":
torch._inductor.config.force_fuse_int_mm_with_mul = True
change_linear_weights_to_int8_dqtensors(module)
elif args.quantization == "int8weightonly":
torch._inductor.config.use_mixed_mm = True
change_linear_weights_to_int8_woqtensors(module)
elif args.quantization == "int4weightonly":
change_linear_weights_to_int4_woqtensors(module)

# used for correctness checks, to avoid triton rand() behaving differently from torch rand().
torchinductor.config.fallback_random = bool(args.torchinductor_fallback_random)
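
For illustration, below is a minimal standalone sketch of what the new int8weightonly path does to a model before compilation; the toy module, input shape, and dtype are assumptions for the example, not part of the commit:

import torch
import torch.nn as nn
from torchao.quantization import change_linear_weights_to_int8_woqtensors

# Toy stand-in for a benchmark model's eval module (assumption, not from the commit).
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512)).cuda().to(torch.bfloat16).eval()
example_input = torch.randn(8, 512, device="cuda", dtype=torch.bfloat16)

torch._inductor.config.use_mixed_mm = True       # same inductor knob the int8weightonly branch sets
change_linear_weights_to_int8_woqtensors(model)  # swap nn.Linear weights for int8 weight-only tensors in place
compiled = torch.compile(model, backend="inductor")
with torch.no_grad():
    out = compiled(example_input)

The int8dynamic and int4weightonly branches follow the same pattern with their respective change_linear_weights_to_* helpers; int8dynamic additionally sets force_fuse_int_mm_with_mul, as shown above.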
13 changes: 13 additions & 0 deletions userbenchmark/group_bench/configs/torch_ao.yaml
@@ -0,0 +1,13 @@
model: "*"
test: eval
device: cuda
extra_args: --precision bf16 --torchdynamo inductor --inductor-compile-mode max-autotune
metrics:
- latencies
test_group:
test_batch_size_default:
subgroup:
- extra_args:
- extra_args: --quantization int8dynamic
- extra_args: --quantization int8weightonly
- extra_args: --quantization int4weightonly
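
For reference, load_group_config in run.py (below) concatenates the baseline, group, and subgroup extra_args, so the int8dynamic subgroup for each model resolves to roughly `--precision bf16 --torchdynamo inductor --inductor-compile-mode max-autotune --quantization int8dynamic`, while the empty first subgroup keeps the unquantized baseline.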
57 changes: 36 additions & 21 deletions userbenchmark/group_bench/run.py
@@ -27,7 +27,7 @@

@dataclasses.dataclass
class TorchBenchGroupBenchConfig:
baseline_config: TorchBenchModelConfig
baseline_configs: List[TorchBenchModelConfig]
metrics: List[str]
group_configs: Dict[str, List[TorchBenchModelConfig]]

@@ -37,7 +37,7 @@ def configs(self):

def config_to_str(config: TorchBenchModelConfig) -> str:
metrics_base = f"model={config.name}, test={config.test}, device={config.device}," + \
f" bs={config.batch_size}, extra_args={config.extra_args}"
f" bs={config.batch_size}, extra_args={' '.join(config.extra_args)}"
return metrics_base

def str_to_config(metric_name: str) -> TorchBenchModelConfig:
@@ -80,10 +80,10 @@ def init_output_dir(configs: List[TorchBenchModelConfig], output_dir: pathlib.Pa
return result

def get_metrics(metrics: List[str], config: TorchBenchModelConfig) -> List[str]:
if metrics:
return metrics
if "--accuracy" in config.extra_args:
return ["accuracy"]
if metrics:
return metrics
return ["latencies", "cpu_peak_mem", "gpu_peak_mem"]

def validate(candidates: List[str], choices: List[str]) -> List[str]:
@@ -150,37 +150,52 @@ def run_config_accuracy(config: TorchBenchModelConfig, metrics: List[str], dryru
print(" [oserror]", flush=True)
return {"accuracy": str(e)}

def models_from_config(config) -> List[str]:
assert "model" in config, f"We expect users to define models in config file."
if isinstance(config["model"], str):
if config["model"] == "*":
return list_models()
else:
return [config["model"]]
assert isinstance(config["model"], list), "Config model must be a list or string."
return config["model"]

def load_group_config(config_file: str) -> TorchBenchGroupBenchConfig:
if not os.path.exists(config_file):
config_file = os.path.join(DEFAULT_CONFIG_DIR, config_file)
with open(config_file, "r") as fp:
data = yaml.safe_load(fp)
baseline_config = TorchBenchModelConfig(
name=data["model"],
test=data["test"],
device=data["device"],
batch_size=data["batch_size"] if "batch_size" in data else None,
extra_args=data["extra_args"].split(" ") if "extra_args" in data else [],
)
baseline_configs = [
TorchBenchModelConfig(
name=model,
test=data["test"],
device=data["device"],
batch_size=data["batch_size"] if "batch_size" in data else None,
extra_args=data["extra_args"].split(" ") if "extra_args" in data else [],
) for model in models_from_config(data)
]
metrics = data["metrics"] if "metrics" in data else []
group_configs = {}
for group_name in data["test_group"]:
group_configs[group_name] = []
group_extra_args = data["test_group"][group_name]["extra_args"].split(" ")
subgroup_extra_args_list = list(map(lambda x: x["extra_args"].split(" "), data["test_group"][group_name]["subgroup"]))
for subgroup_extra_args in subgroup_extra_args_list:
subgroup_config = copy.deepcopy(baseline_config)
subgroup_config.extra_args.extend(group_extra_args)
subgroup_config.extra_args.extend(subgroup_extra_args)
group_configs[group_name].append(subgroup_config)
return TorchBenchGroupBenchConfig(baseline_config, metrics, group_configs)
group_configs[group_name] = []
group_extra_args = list(filter(lambda x: bool(x), data["test_group"][group_name].get("extra_args", "").split(" ")))
for subgroup in data["test_group"][group_name]["subgroup"]:
subgroup_extra_args = subgroup.get("extra_args", "")
subgroup_extra_args = "" if subgroup_extra_args is None else subgroup_extra_args
subgroup_extra_args_list = list(filter(lambda x: bool(x), subgroup_extra_args.split(" ")))
for baseline_config in baseline_configs:
subgroup_config = copy.deepcopy(baseline_config)
subgroup_config.extra_args.extend(group_extra_args)
subgroup_config.extra_args.extend(subgroup_extra_args_list)
group_configs[group_name].append(subgroup_config)
return TorchBenchGroupBenchConfig(baseline_configs, metrics, group_configs)

def parse_args(args: List[str]):
parser = argparse.ArgumentParser()
parser.add_argument("--config", "-c", required=True, help="YAML config to specify group of tests to run.")
parser.add_argument("--dryrun", action="store_true", help="Dryrun the command.")
parser.add_argument("--debug", action="store_true", help="Save the debug output.")
parser.add_argument("--output", default=f"/tmp/{BM_NAME}", help="Output torchbench userbenchmark metrics file path.")
parser.add_argument("--output", default=get_default_output_json_path(BM_NAME), help="Output torchbench userbenchmark metrics file path.")
return parser.parse_args(args)

def run(args: List[str]):
12 changes: 8 additions & 4 deletions userbenchmark/utils.py
@@ -93,10 +93,14 @@ def get_output_json(bm_name, metrics) -> Dict[str, Any]:
}


def get_output_dir(bm_name) -> Path:
current_dir = Path(os.path.dirname(os.path.abspath(__file__)))
target_dir = current_dir.parent.joinpath(USERBENCHMARK_OUTPUT_PREFIX, bm_name)
target_dir.mkdir(exist_ok=True, parents=True)
def get_output_dir(bm_name: str) -> Path:
import torch
IS_FBCODE = False if hasattr(torch.version, "git_version") else True
if not IS_FBCODE:
current_dir = Path(os.path.dirname(os.path.abspath(__file__)))
target_dir = current_dir.parent.joinpath(USERBENCHMARK_OUTPUT_PREFIX, bm_name)
else:
target_dir = Path(f"/tmp/{bm_name}")
return target_dir


