From a1f4b2e808124b645224b71aa4b4dc0db593c7e1 Mon Sep 17 00:00:00 2001 From: FindHao Date: Fri, 4 Oct 2024 14:07:49 -0700 Subject: [PATCH] Add multiple ops support for --op argument (#2490) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Allow users to benchmark multiple ops in a single run. The ops can be split by commas, `--op fp8_gemm,addmm` Example output: ``` % python run_benchmark.py triton --op fp8_gemm,addmm --num-inputs 1 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00, 3.12s/it] x_val torch_fp8_gemm-gbps torch_fp8_gemm-gbps torch_fp8_gemm-latency torch_fp8_gemm-tflops triton_fp8_gemm-gbps triton_fp8_gemm-gbps triton_fp8_gemm-latency triton_fp8_gemm-tflops ------------------ --------------------- --------------------- ------------------------ ----------------------- ---------------------- ---------------------- ------------------------- ------------------------ (1024, 1024, 1024) 462.202 462.202 0.00907462 236.647 630.43 630.43 0.00665309 322.78 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00, 5.90s/it] (M, N, K) aten_addmm-best_config aten_addmm-gbps aten_addmm-tflops triton_addmm-best_config triton_addmm-gbps triton_addmm-tflops pt2_triton_matmul-best_config pt2_triton_matmul-gbps pt2_triton_matmul-tflops ------------------ ------------------------ ----------------- ------------------- ------------------------------------------------------------------------------------------------------------- ------------------- --------------------- ------------------------------- 
------------------------ -------------------------- (20120, 512, 1536) 818.112 247.544 {'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8, 'num_warps': 8, 'num_ctas': 1, 'num_stages': 3} 911.569 275.823 889.125 269.031 ``` Pull Request resolved: https://github.com/pytorch/benchmark/pull/2490 Reviewed By: xuzhao9 Differential Revision: D63862548 Pulled By: FindHao fbshipit-source-id: 9d4afa6051d4191bc2e3288f59e2820627647b91 --- userbenchmark/triton/run.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/userbenchmark/triton/run.py b/userbenchmark/triton/run.py index f8ab0821a..ad2c58c2a 100644 --- a/userbenchmark/triton/run.py +++ b/userbenchmark/triton/run.py @@ -29,7 +29,12 @@ def get_parser(args=None): parser = argparse.ArgumentParser(allow_abbrev=False) - parser.add_argument("--op", type=str, required=False, help="Operator to benchmark.") + parser.add_argument( + "--op", + type=str, + required=False, + help="Operators to benchmark. Split with comma if multiple.", + ) parser.add_argument( "--mode", choices=["fwd", "bwd", "fwd_bwd", "fwd_no_grad"], @@ -188,5 +193,11 @@ def run(args: List[str] = []): run_ci() return + if args.op: + ops = args.op.split(",") + else: + ops = [] with gpu_lockdown(args.gpu_lockdown): - _run(args, extra_args) + for op in ops: + args.op = op + _run(args, extra_args)