diff --git a/torchbenchmark/util/backends/torchdynamo.py b/torchbenchmark/util/backends/torchdynamo.py index 650da74bf8..9f0e12bf7c 100644 --- a/torchbenchmark/util/backends/torchdynamo.py +++ b/torchbenchmark/util/backends/torchdynamo.py @@ -88,6 +88,11 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace: action='store_true', help="set to generate unique triton kernel names in Inductor" ) + parser.add_argument( + "--torchinductor_post_grad_batch_fusion", + type=distutils.util.strtobool, + help="Enable batch linear fusion in the Inductor post-grad pass." + ) parser.add_argument( "--dynamo_disable_optimizer_step", type=distutils.util.strtobool, @@ -119,6 +124,8 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar if args.torchdynamo == "inductor": import torch._inductor as torchinductor torchinductor.config.triton.cudagraphs = bool(args.torchinductor_cudagraph) + if bool(args.torchinductor_post_grad_batch_fusion): + torchinductor.config.post_grad_fusion_options["batch_linear"] = {} torch._inductor.config.debug = bool(args.dump_triton) # Setup torchinductor.config.triton.mm