Add max_autotune_gemm and dump_triton options
Summary: Add a --torchinductor_enable_max_autotune_gemm flag that sets torch._inductor.config.max_autotune_gemm, and a --dump_triton flag that sets torch._inductor.config.debug so generated Triton code is dumped.

Reviewed By: luccafong

Differential Revision: D50703503

fbshipit-source-id: 46c8e28c51f61ad766ac1575ad0a5289ffe1feeb
xuzhao9 authored and facebook-github-bot committed Oct 26, 2023
1 parent f253464 commit edf7115
Showing 1 changed file with 14 additions and 0 deletions.

 torchbenchmark/util/backends/torchdynamo.py | 14 ++++++++++++++
@@ -73,6 +73,11 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
         action='store_true',
         help="enable batch fusion in Inductor"
     )
+    parser.add_argument(
+        "--torchinductor_enable_max_autotune_gemm",
+        action='store_true',
+        help="Enable max autotune gemm"
+    )
     parser.add_argument(
         "--torchinductor_enable_split_cat_fx_pass",
         action='store_true',
@@ -88,6 +93,12 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
         type=distutils.util.strtobool,
         default="false",
     )
+    parser.add_argument(
+        "--dump_triton",
+        type=distutils.util.strtobool,
+        default="false",
+        help="Enable triton code dump by setting torch._inductor.config.debug",
+    )
     args, extra_args = parser.parse_known_args(dynamo_args)
     return args, extra_args
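
Both new flags end up toggling torch._inductor.config, but they are declared in two styles: --torchinductor_enable_max_autotune_gemm is a plain store_true switch, while --dump_triton follows this file's existing strtobool pattern and takes an explicit true/false value. A minimal standalone sketch of that pattern (the parser and argv below are illustrative, not part of this diff):

    import argparse
    import distutils.util  # strtobool returns 1/0; deprecated since Python 3.10

    parser = argparse.ArgumentParser()
    parser.add_argument("--dump_triton", type=distutils.util.strtobool, default="false")

    # "--dump_triton true" parses to 1; the string default "false" parses to 0.
    args = parser.parse_args(["--dump_triton", "true"])
    assert bool(args.dump_triton) is True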

@@ -108,6 +119,7 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar
     if args.torchdynamo == "inductor":
         import torch._inductor as torchinductor
         torchinductor.config.triton.cudagraphs = bool(args.torchinductor_cudagraph)
+        torch._inductor.config.debug = bool(args.dump_triton)
 
         # Setup torchinductor.config.triton.mm
         if args.tritonmm == "triton":
@@ -125,6 +137,8 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar
             torchinductor.config.split_cat_fx_passes = True
         if args.torchinductor_triton_unique_kernel_names:
             torchinductor.config.triton.unique_kernel_names = True
+        if args.torchinductor_enable_max_autotune_gemm:
+            torchinductor.config.max_autotune_gemm = True
 
         # used for correctness checks, to avoid triton rand() behaving differently from torch rand().
         torchinductor.config.fallback_random = bool(args.torchinductor_fallback_random)
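End to end, each option wires through the same way: a flag parsed in parse_torchdynamo_args flips the matching inductor config knob in apply_torchdynamo_args once the inductor backend is selected. A round-trip sketch under stated assumptions (the import path mirrors this file's location, and torch plus torchbench must be installed):

    import torch._inductor as torchinductor
    from torchbenchmark.util.backends.torchdynamo import parse_torchdynamo_args

    args, extra_args = parse_torchdynamo_args(
        ["--torchinductor_enable_max_autotune_gemm", "--dump_triton", "true"]
    )

    # Mirrors the additions to apply_torchdynamo_args in this diff:
    torchinductor.config.debug = bool(args.dump_triton)  # dump generated Triton code
    if args.torchinductor_enable_max_autotune_gemm:
        torchinductor.config.max_autotune_gemm = True  # benchmark GEMM choices, keep the fastest

From the benchmark CLI this would look something like: python run.py <model> --torchdynamo inductor --torchinductor_enable_max_autotune_gemm --dump_triton true (the run.py entry point and the model placeholder are illustrative, not part of this commit).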
