From e95390a1f942b08d1b343a0ceec141efdeddc40c Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Mon, 9 Sep 2024 17:19:00 -0700 Subject: [PATCH] Turn on TMA by default for row-wise GEMM (#2450) Summary: Pull Request resolved: https://github.com/pytorch/benchmark/pull/2450 X-link: https://github.com/facebookresearch/FBGEMM/pull/189 Enabling the TMA row-wise GEMM by default it TMA appears to give quite some speedup across-the-board, up to 40% for some shapes. Reviewed By: choutim Differential Revision: D62212842 fbshipit-source-id: 59220cec90e222fe91be9f53a3477f1c38e02e2a --- torchbenchmark/operators/fp8_gemm_rowwise/operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchbenchmark/operators/fp8_gemm_rowwise/operator.py b/torchbenchmark/operators/fp8_gemm_rowwise/operator.py index 482949c78..eddc32f05 100644 --- a/torchbenchmark/operators/fp8_gemm_rowwise/operator.py +++ b/torchbenchmark/operators/fp8_gemm_rowwise/operator.py @@ -24,7 +24,7 @@ def parse_args(args: List[str]) -> argparse.Namespace: parser.add_argument( "--no_fp8_fast_accum", dest="fp8_fast_accum", action="store_false" ) - parser.add_argument("--use_tma", action="store_true") + parser.add_argument("--no_use_tma", dest="use_tma", action="store_false") args = parser.parse_args(args) return args