Skip to content

Commit

Permalink
Add test script for pt2 batch fusion (#2018)
Browse files Browse the repository at this point in the history
Summary:

Add test case and torchinductor option for pt2 batch fusion

Reviewed By: ckluk2

Differential Revision: D47609030
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Nov 29, 2023
1 parent fba74bb commit 0a591c7
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions torchbenchmark/util/backends/torchdynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
action='store_true',
help="set to generate unique triton kernel names in Inductor"
)
parser.add_argument(
"--torchinductor_post_grad_batch_fusion",
type=distutils.util.strtobool,
help="Enable BMM Linear Fusion."
)
parser.add_argument(
"--dynamo_disable_optimizer_step",
type=distutils.util.strtobool,
Expand Down Expand Up @@ -119,6 +124,7 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar
if args.torchdynamo == "inductor":
import torch._inductor as torchinductor
torchinductor.config.triton.cudagraphs = bool(args.torchinductor_cudagraph)
torchinductor.config.post_grad_fusion_options["batch_fusion"] = {}
torch._inductor.config.debug = bool(args.dump_triton)

# Setup torchinductor.config.triton.mm
Expand Down

0 comments on commit 0a591c7

Please sign in to comment.