Skip to content

Commit

Permalink
Add test script for pt2 batch fusion (pytorch#2018)
Browse files Browse the repository at this point in the history
Summary:

Add test case and torchinductor option for pt2 batch fusion

Reviewed By: ckluk2

Differential Revision: D47609030
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Nov 29, 2023
1 parent 827606c commit 971bcbd
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions torchbenchmark/util/backends/torchdynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
action='store_true',
help="set to generate unique triton kernel names in Inductor"
)
parser.add_argument(
"--torchinductor_post_grad_batch_fusion",
type=distutils.util.strtobool,
help="Enable BMM Linear Fusion."
)
parser.add_argument(
"--dynamo_disable_optimizer_step",
type=distutils.util.strtobool,
Expand Down Expand Up @@ -119,6 +124,8 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar
if args.torchdynamo == "inductor":
import torch._inductor as torchinductor
torchinductor.config.triton.cudagraphs = bool(args.torchinductor_cudagraph)
if bool(args.torchinductor_post_grad_batch_fusion):
torchinductor.config.post_grad_fusion_options["batch_fusion"] = {}
torch._inductor.config.debug = bool(args.dump_triton)

# Setup torchinductor.config.triton.mm
Expand Down

0 comments on commit 971bcbd

Please sign in to comment.