Skip to content

Commit

Permalink
testing autoquant
Browse files (browse the repository at this point in the history)
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
  • Loading branch information
HDCharles committed May 14, 2024
1 parent 67ef897 commit 0fbdca7
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 13 deletions.
20 changes: 18 additions & 2 deletions torchbenchmark/util/backends/torchdynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
)
parser.add_argument(
"--quantization",
choices=["int8dynamic", "int8weightonly", "int4weightonly"],
choices=["int8dynamic", "int8weightonly", "int4weightonly","autoquant"],
help="Apply quantization to the model before running it",
)
parser.add_argument(
Expand Down Expand Up @@ -187,21 +187,37 @@ def apply_torchdynamo_args(
change_linear_weights_to_int4_woqtensors,
change_linear_weights_to_int8_dqtensors,
change_linear_weights_to_int8_woqtensors,

)


torch._dynamo.config.automatic_dynamic_shapes = False
torch._dynamo.config.force_parameter_static_shapes = False
torch._dynamo.config.cache_size_limit = 1000
assert "cuda" in model.device
module, example_inputs = model.get_module()
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True
if isinstance(example_inputs, tuple([tuple, list])):
example_inputs = tuple([
x.to(torch.bfloat16)
if isinstance(x, torch.Tensor) and x.dtype in [torch.float32, torch.float16]
else x
for x in example_inputs
])
module=module.to(torch.bfloat16)
with torch.no_grad():
module(*example_inputs)
if args.quantization == "int8dynamic":
torch._inductor.config.force_fuse_int_mm_with_mul = True
change_linear_weights_to_int8_dqtensors(module)
elif args.quantization == "int8weightonly":
torch._inductor.config.use_mixed_mm = True
change_linear_weights_to_int8_woqtensors(module)
elif args.quantization == "int4weightonly":
change_linear_weights_to_int4_woqtensors(module)
elif args.quantization == "autoquant":
torchao.autoquant(module, example_inputs, error_on_unseen=False)


if args.freeze_prepack_weights:
torch._inductor.config.freezing = True
Expand Down
7 changes: 3 additions & 4 deletions userbenchmark/group_bench/configs/torch_ao.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@ device: cuda
extra_args: --precision bf16 --torchdynamo inductor --inductor-compile-mode max-autotune
metrics:
- latencies
- gpu_peak_mem
test_group:
test_batch_size_default:
subgroup:
- extra_args:
- extra_args: --quantization int8dynamic
- extra_args: --quantization int8weightonly
- extra_args: --quantization int4weightonly
- extra_args: --quantization noquant
- extra_args: --quantization autoquant
20 changes: 13 additions & 7 deletions userbenchmark/group_bench/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def models_from_config(config) -> List[str]:
basic_models_list = list_models()
else:
basic_models_list = [config["model"]]
breakpoint()
assert isinstance(config["model", list]), "Config model must be a list or string."
basic_models_list = config["model"]
extended_models_list = []
Expand Down Expand Up @@ -218,14 +219,19 @@ def run(args: List[str]):
results = {}
try:
for config in group_config.configs:
metrics = get_metrics(group_config.metrics, config)
if "accuracy" in metrics:
metrics_dict = run_config_accuracy(config, metrics, dryrun=args.dryrun)
else:
metrics_dict = run_config(config, metrics, dryrun=args.dryrun)
try:
metrics = get_metrics(group_config.metrics, config)
if "accuracy" in metrics:
metrics_dict = run_config_accuracy(config, metrics, dryrun=args.dryrun)
else:
metrics_dict = run_config(config, metrics, dryrun=args.dryrun)
except KeyboardInterrupt:
raise KeyboardInterrupt
except Exception as e:
metrics_dict = {}
config_str = config_to_str(config)
for metric in metrics_dict:
results[f"{config_str}, metric={metric}"] = metrics_dict[metric]
for metric in metrics:
results[f"{config_str}, metric={metric}"] = metrics_dict.get(metric, "err")
except KeyboardInterrupt:
print("User keyboard interrupted!")
result = get_output_json(BM_NAME, results)
Expand Down

0 comments on commit 0fbdca7

Please sign in to comment.