Skip to content

Commit

Permalink
Add test script for pt2 batch fusion
Browse files Browse the repository at this point in the history
Summary: Add test case and torchinductor option for pt2 batch fusion

Reviewed By: ckluk2

Differential Revision: D47609030
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Oct 30, 2023
1 parent 619f159 commit 4d41320
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions torchbenchmark/util/backends/torchdynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
action='store_true',
help="set to generate unique triton kernel names in Inductor"
)
parser.add_argument(
"--torchinductor_batch_fusion",
type=distutils.util.strtobool,
help="Enable BMM Linear Fusion."
)
parser.add_argument(
"--dynamo_disable_optimizer_step",
type=distutils.util.strtobool,
Expand Down Expand Up @@ -119,6 +124,7 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar
if args.torchdynamo == "inductor":
import torch._inductor as torchinductor
torchinductor.config.triton.cudagraphs = bool(args.torchinductor_cudagraph)
torchinductor.config.post_grad_batch_fusion = bool(args.torchinductor_batch_fusion)
torch._inductor.config.debug = bool(args.dump_triton)

# Setup torchinductor.config.triton.mm
Expand Down

0 comments on commit 4d41320

Please sign in to comment.