Enable torchao quantization in framework and group_bench #2116

Closed · wants to merge 1 commit
34 changes: 34 additions & 0 deletions torchbenchmark/util/backends/torchdynamo.py
@@ -78,6 +78,12 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
action='store_true',
help="Enable max autotune gemm"
)
parser.add_argument(
"--inductor-compile-mode",
default=None,
choices=['max-autotune'],
help="torch.compile mode argument for inductor runs.",
)
parser.add_argument(
"--torchinductor_enable_split_cat_fx_pass",
action='store_true',
@@ -104,6 +110,11 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
default="false",
help="Enable triton code dump by setting torch._inductor.config.debug",
)
parser.add_argument(
"--quantization",
choices=["int8dynamic", "int8weightonly", "int4weightonly"],
help="Apply quantization to the model before running it",
)
args, extra_args = parser.parse_known_args(dynamo_args)
return args, extra_args

@@ -123,6 +134,9 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar

if args.torchdynamo == "inductor":
import torch._inductor as torchinductor
if args.inductor_compile_mode == "max-autotune":
torchinductor.config.max_autotune = True
torchinductor.config.triton.cudagraphs = True
torchinductor.config.triton.cudagraphs = bool(args.torchinductor_cudagraph)
if bool(args.torchinductor_post_grad_batch_fusion):
torchinductor.config.post_grad_fusion_options["batch_linear_post_grad"] = {}
@@ -146,6 +160,26 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar
torchinductor.config.triton.unique_kernel_names = True
if args.torchinductor_enable_max_autotune_gemm:
torchinductor.config.max_autotune_gemm = True
if args.quantization:
import torchao
from torchao.quantization import (
change_linear_weights_to_int8_dqtensors,
change_linear_weights_to_int8_woqtensors,
change_linear_weights_to_int4_woqtensors
)
torch._dynamo.config.automatic_dynamic_shapes = False
torch._dynamo.config.force_parameter_static_shapes = False
torch._dynamo.config.cache_size_limit = 1000
assert "cuda" in model.device
module, example_inputs = model.get_module()
if args.quantization == "int8dynamic":
torch._inductor.config.force_fuse_int_mm_with_mul = True
change_linear_weights_to_int8_dqtensors(module)
elif args.quantization == "int8weightonly":
torch._inductor.config.use_mixed_mm = True
change_linear_weights_to_int8_woqtensors(module)
elif args.quantization == "int4weightonly":
change_linear_weights_to_int4_woqtensors(module)

# used for correctness checks, to avoid triton rand() behaving differently from torch rand().
torchinductor.config.fallback_random = bool(args.torchinductor_fallback_random)
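
For reference, a minimal sketch of what the new --quantization path does outside the benchmark harness, reusing the torchao helpers imported in this hunk; the toy module, tensor sizes, and the int8dynamic choice are illustrative only, not part of the PR:

import torch
import torch._dynamo
import torch._inductor
from torchao.quantization import change_linear_weights_to_int8_dqtensors

# Toy eval module standing in for a TorchBench model's module.
module = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU()).cuda().eval()
example_input = torch.randn(32, 512, device="cuda")

# Mirror the dynamo/inductor settings applied above for --quantization int8dynamic.
torch._dynamo.config.automatic_dynamic_shapes = False
torch._dynamo.config.force_parameter_static_shapes = False
torch._inductor.config.force_fuse_int_mm_with_mul = True

# Swap Linear weights for int8 dynamically quantized tensor subclasses, then compile.
change_linear_weights_to_int8_dqtensors(module)
compiled = torch.compile(module, mode="max-autotune")
out = compiled(example_input)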
13 changes: 13 additions & 0 deletions userbenchmark/group_bench/configs/torch_ao.yaml
@@ -0,0 +1,13 @@
model: "*"
test: eval
device: cuda
extra_args: --precision bf16 --torchdynamo inductor --inductor-compile-mode max-autotune
Review comment from the PR author:
@HDCharles In this config, we define the baseline extra_args as --precision bf16 --torchdynamo inductor --inductor-compile-mode max-autotune, so these arguments are applied to every test_group/subgroup defined below.

metrics:
- latencies
Review comment (xuzhao9, Jan 26, 2024):
To add CPU/GPU peak memory, add cpu_peak_mem and gpu_peak_mem here. @HDCharles
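For example, the metrics list could then read metrics: [latencies, cpu_peak_mem, gpu_peak_mem], matching the default metric names used by get_metrics in run.py below.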

test_group:
test_batch_size_default:
subgroup:
- extra_args:
- extra_args: --quantization int8dynamic
Review comment from the PR author:
As shown in the D52802534 test plan, here are the test results:

Running TorchBenchModelConfig(name='resnet50', test='eval', device='cuda', batch_size=None, extra_args=['--precision', 'bf16', '--torchdynamo', 'inductor', '--inductor-compile-mode', 'max-autotune', '--quantization', 'int8dynamic'], extra_env=None, output_dir=None) ... [done]
Running TorchBenchModelConfig(name='resnet50', test='eval', device='cuda', batch_size=None, extra_args=['--precision', 'bf16', '--torchdynamo', 'inductor', '--inductor-compile-mode', 'max-autotune', '--quantization', 'int8weightonly'], extra_env=None, output_dir=None) ... [done]
Running TorchBenchModelConfig(name='resnet50', test='eval', device='cuda', batch_size=None, extra_args=['--precision', 'bf16', '--torchdynamo', 'inductor', '--inductor-compile-mode', 'max-autotune', '--quantization', 'int4weightonly'], extra_env=None, output_dir=None) ... [done]

They all run with the compiler enabled.

- extra_args: --quantization int8weightonly
- extra_args: --quantization int4weightonly
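
As a usage sketch (not spelled out in the PR), this config would presumably be launched through the TorchBench userbenchmark entry point, e.g. python run_benchmark.py group_bench -c torch_ao.yaml; parse_args in run.py below only requires --config/-c, and load_group_config falls back to the default config directory when the given path does not exist.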
57 changes: 36 additions & 21 deletions userbenchmark/group_bench/run.py
@@ -27,7 +27,7 @@

@dataclasses.dataclass
class TorchBenchGroupBenchConfig:
baseline_config: TorchBenchModelConfig
baseline_configs: List[TorchBenchModelConfig]
metrics: List[str]
group_configs: Dict[str, List[TorchBenchModelConfig]]

@@ -37,7 +37,7 @@ def configs(self):

def config_to_str(config: TorchBenchModelConfig) -> str:
metrics_base = f"model={config.name}, test={config.test}, device={config.device}," + \
f" bs={config.batch_size}, extra_args={config.extra_args}"
f" bs={config.batch_size}, extra_args={' '.join(config.extra_args)}"
return metrics_base

def str_to_config(metric_name: str) -> TorchBenchModelConfig:
@@ -80,10 +80,10 @@ def init_output_dir(configs: List[TorchBenchModelConfig], output_dir: pathlib.Pa
return result

def get_metrics(metrics: List[str], config: TorchBenchModelConfig) -> List[str]:
if metrics:
return metrics
if "--accuracy" in config.extra_args:
return ["accuracy"]
if metrics:
return metrics
return ["latencies", "cpu_peak_mem", "gpu_peak_mem"]

def validate(candidates: List[str], choices: List[str]) -> List[str]:
@@ -150,37 +150,52 @@ def run_config_accuracy(config: TorchBenchModelConfig, metrics: List[str], dryru
print(" [oserror]", flush=True)
return {"accuracy": str(e)}

def models_from_config(config) -> List[str]:
assert "model" in config, f"We expect users to define models in config file."
if isinstance(config["model"], str):
if config["model"] == "*":
return list_models()
else:
return [config["model"]]
assert isinstance(config["model"], list), "Config model must be a list or string."
return config["model"]

def load_group_config(config_file: str) -> TorchBenchGroupBenchConfig:
if not os.path.exists(config_file):
config_file = os.path.join(DEFAULT_CONFIG_DIR, config_file)
with open(config_file, "r") as fp:
data = yaml.safe_load(fp)
baseline_config = TorchBenchModelConfig(
name=data["model"],
test=data["test"],
device=data["device"],
batch_size=data["batch_size"] if "batch_size" in data else None,
extra_args=data["extra_args"].split(" ") if "extra_args" in data else [],
)
baseline_configs = [
TorchBenchModelConfig(
name=model,
test=data["test"],
device=data["device"],
batch_size=data["batch_size"] if "batch_size" in data else None,
extra_args=data["extra_args"].split(" ") if "extra_args" in data else [],
) for model in models_from_config(data)
]
metrics = data["metrics"] if "metrics" in data else []
group_configs = {}
for group_name in data["test_group"]:
group_configs[group_name] = []
group_extra_args = data["test_group"][group_name]["extra_args"].split(" ")
subgroup_extra_args_list = list(map(lambda x: x["extra_args"].split(" "), data["test_group"][group_name]["subgroup"]))
for subgroup_extra_args in subgroup_extra_args_list:
subgroup_config = copy.deepcopy(baseline_config)
subgroup_config.extra_args.extend(group_extra_args)
subgroup_config.extra_args.extend(subgroup_extra_args)
group_configs[group_name].append(subgroup_config)
return TorchBenchGroupBenchConfig(baseline_config, metrics, group_configs)
group_configs[group_name] = []
group_extra_args = list(filter(lambda x: bool(x), data["test_group"][group_name].get("extra_args", "").split(" ")))
for subgroup in data["test_group"][group_name]["subgroup"]:
subgroup_extra_args = subgroup.get("extra_args", "")
subgroup_extra_args = "" if subgroup_extra_args is None else subgroup_extra_args
subgroup_extra_args_list = list(filter(lambda x: bool(x), subgroup_extra_args.split(" ")))
for baseline_config in baseline_configs:
subgroup_config = copy.deepcopy(baseline_config)
subgroup_config.extra_args.extend(group_extra_args)
subgroup_config.extra_args.extend(subgroup_extra_args_list)
group_configs[group_name].append(subgroup_config)
return TorchBenchGroupBenchConfig(baseline_configs, metrics, group_configs)

def parse_args(args: List[str]):
parser = argparse.ArgumentParser()
parser.add_argument("--config", "-c", required=True, help="YAML config to specify group of tests to run.")
parser.add_argument("--dryrun", action="store_true", help="Dryrun the command.")
parser.add_argument("--debug", action="store_true", help="Save the debug output.")
parser.add_argument("--output", default=f"/tmp/{BM_NAME}", help="Output torchbench userbenchmark metrics file path.")
parser.add_argument("--output", default=get_default_output_json_path(BM_NAME), help="Output torchbench userbenchmark metrics file path.")
return parser.parse_args(args)

def run(args: List[str]):
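A small sketch of how load_group_config expands torch_ao.yaml under this change, assuming the diff above is applied; the hypothetical printout mirrors the resnet50 extra_args quoted in the review comment:

from userbenchmark.group_bench.run import load_group_config

group_config = load_group_config("torch_ao.yaml")
# model: "*" expands to one baseline config per TorchBench model.
print(len(group_config.baseline_configs))
# Each subgroup deep-copies a baseline config and appends its own extra_args,
# e.g. [..., '--inductor-compile-mode', 'max-autotune', '--quantization', 'int8dynamic'].
for cfg in group_config.group_configs["test_batch_size_default"][:3]:
    print(cfg.name, cfg.extra_args)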
12 changes: 8 additions & 4 deletions userbenchmark/utils.py
@@ -93,10 +93,14 @@ def get_output_json(bm_name, metrics) -> Dict[str, Any]:
}


def get_output_dir(bm_name) -> Path:
current_dir = Path(os.path.dirname(os.path.abspath(__file__)))
target_dir = current_dir.parent.joinpath(USERBENCHMARK_OUTPUT_PREFIX, bm_name)
target_dir.mkdir(exist_ok=True, parents=True)
def get_output_dir(bm_name: str) -> Path:
import torch
IS_FBCODE = not hasattr(torch.version, "git_version")
if not IS_FBCODE:
current_dir = Path(os.path.dirname(os.path.abspath(__file__)))
target_dir = current_dir.parent.joinpath(USERBENCHMARK_OUTPUT_PREFIX, bm_name)
else:
target_dir = Path(f"/tmp/{bm_name}")
return target_dir
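
The OSS-vs-fbcode split above keys off torch.version.git_version, which open-source builds normally carry; builds without it are treated as internal and write their output under /tmp/<bm_name> rather than the in-repo USERBENCHMARK_OUTPUT_PREFIX directory.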

