From 2d4f714ee3eed33bf927eb07fcdcd9a3c1fec102 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Sat, 20 Jul 2024 20:56:42 -0400 Subject: [PATCH] Update GEMM list of problem sizes (#1662) The new list includes all problem sizes we reported in presentations. Signed-off-by: Whitney Tsang --- benchmarks/xetla_benchmark/gemm_benchmark.py | 29 ++++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/benchmarks/xetla_benchmark/gemm_benchmark.py b/benchmarks/xetla_benchmark/gemm_benchmark.py index 3e5c4ea82f..19421dd335 100644 --- a/benchmarks/xetla_benchmark/gemm_benchmark.py +++ b/benchmarks/xetla_benchmark/gemm_benchmark.py @@ -324,18 +324,35 @@ def matmul(a, b): @triton.testing.perf_report( triton.testing.Benchmark( # argument names to use as an x-axis for the plot - x_names=['B', 'M', 'N', 'K'], + x_names=['B', 'M', 'K', 'N'], # different possible values for `x_name` - x_vals=[[1, 256 * i, 256 * i, 256 * i] for i in range(1, 17)] + # - [[4, 32768, 128, 4096], # + x_vals=[[1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] + # + [[1, 1, 5120, 13824], # + [1, 4, 4096, 12288], # + [1, 512, 8192, 8192], # + [1, 512, 8192, 32768], # + [1, 512, 32768, 8192], # + [1, 1024, 16384, 8192], # + [1, 1024, 28672, 8192], # + [1, 3072, 4096, 3072], # + [1, 4096, 16384, 8192], # + [1, 8192, 16384, 1024], # + [1, 8192, 16384, 4096], # + [1, 16384, 1024, 8192], # + [1, 16384, 4096, 8192], # + [1, 16384, 8192, 1024], # + [1, 16384, 8192, 4096], # + [4, 32768, 128, 4096], # [4, 32768, 4096, 128], # - [32, 4096, 4096, 128]], + [32, 4096, 4096, 128], # + [4096, 8, 128, 16384], # + [4096, 8, 16384, 128]], line_arg='provider', # argument name whose value corresponds to a different line in the plot # possible values for `line_arg`` line_vals=['onednn', 'triton', 'xetla'], # label name for the lines - line_names=["oneDNN", "Triton", "Xetla"], + line_names=["oneDNN", "Triton", "XeTLA"], # line styles #styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], ylabel="TFLOPS", # label name for the y-axis @@ -351,7 +368,7 @@ def benchmark(B, M, N, K, provider): a = torch.rand((B, M, K), device='xpu', dtype=torch.bfloat16) b = torch.rand((B, K, N), device='xpu', dtype=torch.bfloat16) - quantiles = [0.5, 0.2, 0.8] + quantiles = [0.5, 0.0, 1.0] # calculate tflops for oneDNN kernel def calculate_tflops(ms):