Copy autotune config from 09 to 03 tutorial (#1185)
<details>
  <summary>Performance with CUDA autotune config</summary>

```
matmul-performance-fp16:
         M       N       K     rocBLAS    Triton
0    256.0   256.0   256.0    4.559026  0.205603
1    384.0   384.0   384.0   11.603095  0.493577
2    512.0   512.0   512.0   23.301689  0.896219
3    640.0   640.0   640.0   36.817977  1.423458
4    768.0   768.0   768.0   42.573762  2.068046
5    896.0   896.0   896.0   58.386619  2.789804
6   1024.0  1024.0  1024.0   71.014673  2.060767
7   1152.0  1152.0  1152.0   86.082425  2.562904
8   1280.0  1280.0  1280.0  103.614232  3.077169
9   1408.0  1408.0  1408.0  119.083163  3.726118
10  1536.0  1536.0  1536.0  122.098335  4.454128
11  1664.0  1664.0  1664.0  141.506240  5.207327
12  1792.0  1792.0  1792.0  153.047478  5.939910
13  1920.0  1920.0  1920.0  148.945453  5.183291
14  2048.0  2048.0  2048.0  155.389555  5.532185
15  2176.0  2176.0  2176.0  150.986335  6.260973
16  2304.0  2304.0  2304.0  146.579467  6.856634
17  2432.0  2432.0  2432.0  167.260059  6.004895
18  2560.0  2560.0  2560.0  177.875489  6.270450
19  2688.0  2688.0  2688.0  172.422980  6.999930
20  2816.0  2816.0  2816.0  157.345508  6.461364
21  2944.0  2944.0  2944.0  166.466907  6.979531
22  3072.0  3072.0  3072.0  137.424294  6.542419
23  3200.0  3200.0  3200.0  169.256199  6.909815
24  3328.0  3328.0  3328.0  181.252667  6.959148
25  3456.0  3456.0  3456.0  176.704802  7.097506
26  3584.0  3584.0  3584.0  184.205672  7.128893
27  3712.0  3712.0  3712.0  188.513353  7.130193
28  3840.0  3840.0  3840.0  182.560943  6.989402
29  3968.0  3968.0  3968.0  182.978727  7.185408
30  4096.0  4096.0  4096.0  197.970383  6.973538
```

</details>
<details>
  <summary>Performance with XPU autotune config</summary>

```
matmul-performance-fp16:
         M       N       K     rocBLAS     Triton
0    256.0   256.0   256.0    4.559026   0.819200
1    384.0   384.0   384.0   11.603095   1.928580
2    512.0   512.0   512.0   23.301689   3.532046
3    640.0   640.0   640.0   36.817977   5.535135
4    768.0   768.0   768.0   42.256049   8.031646
5    896.0   896.0   896.0   58.386619  10.965292
6   1024.0  1024.0  1024.0   71.014673   9.199296
7   1152.0  1152.0  1152.0   85.313826  11.731306
8   1280.0  1280.0  1280.0  104.439840  14.403516
9   1408.0  1408.0  1408.0  119.083163  10.181315
10  1536.0  1536.0  1536.0  120.155126  11.952106
11  1664.0  1664.0  1664.0  142.029687  13.123627
12  1792.0  1792.0  1792.0  153.373808  15.210893
13  1920.0  1920.0  1920.0  145.276852  12.435674
14  2048.0  2048.0  2048.0  154.051912  13.632221
15  2176.0  2176.0  2176.0  150.281626  15.323182
16  2304.0  2304.0  2304.0  139.300581  13.655089
17  2432.0  2432.0  2432.0  164.205086  14.681520
18  2560.0  2560.0  2560.0  180.944949  16.230571
19  2688.0  2688.0  2688.0  168.943325  17.645846
20  2816.0  2816.0  2816.0  158.147830  15.942595
21  2944.0  2944.0  2944.0  166.727962  15.096109
22  3072.0  3072.0  3072.0  144.205279  16.244749
23  3200.0  3200.0  3200.0  172.681282  17.312651
24  3328.0  3328.0  3328.0  183.636628  16.591440
25  3456.0  3456.0  3456.0  180.664581  15.691570
26  3584.0  3584.0  3584.0  188.984735  16.546149
27  3712.0  3712.0  3712.0  189.885057  17.631699
28  3840.0  3840.0  3840.0  183.936805  17.210670
29  3968.0  3968.0  3968.0  194.242794  16.661401
30  4096.0  4096.0  4096.0  199.394961  17.179869
```

</details>
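
For reference, the tables above come from the tutorial's fp16 matmul sweep (M = N = K from 256 to 4096 in steps of 128), with throughput reported in TFLOPS by the tutorial's harness. The snippet below is a minimal sketch of that harness using `triton.testing.perf_report`; it is not part of this diff. The `DEVICE` choice is an assumption, and both providers call `torch.matmul` as a stand-in because the tutorial's Triton `matmul` kernel is not included here.

```python
import torch
import triton

DEVICE = 'xpu'  # assumption: 'xpu' for the Intel backend, 'cuda' elsewhere


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],
        x_vals=[128 * i for i in range(2, 33)],  # 256 .. 4096, as in the tables above
        line_arg='provider',
        line_vals=['ref', 'triton'],
        line_names=['rocBLAS', 'Triton'],
        styles=[('green', '-'), ('blue', '-')],
        ylabel='TFLOPS',
        plot_name='matmul-performance-fp16',
        args={},
    ))
def benchmark(M, N, K, provider):
    a = torch.randn((M, K), device=DEVICE, dtype=torch.float16)
    b = torch.randn((K, N), device=DEVICE, dtype=torch.float16)
    # Stand-in: both providers call torch.matmul here; the tutorial benchmarks
    # its own Triton `matmul` kernel when provider == 'triton'.
    fn = lambda: torch.matmul(a, b)
    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
    tflops = lambda t_ms: 2 * M * N * K * 1e-12 / (t_ms * 1e-3)
    return tflops(ms), tflops(max_ms), tflops(min_ms)


benchmark.run(show_plots=False, print_data=True)
```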

Signed-off-by: Whitney Tsang <whitney.tsang@intel.com>
whitneywhtsang authored May 23, 2024
1 parent a1c6a68 commit 00ac76a
Showing 1 changed file with 20 additions and 2 deletions: `python/tutorials/03-matrix-multiplication.py`
```diff
@@ -170,8 +170,26 @@ def is_xpu():
 
 
 def get_xpu_autotune_config():
-    # FIXME: Add autotune config for XPU.
-    return get_cuda_autotune_config()
+    return [
+        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4}, num_stages=4,
+                      num_warps=32),
+        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=4,
+                      num_warps=32),
+        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=4,
+                      num_warps=32),
+        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=4,
+                      num_warps=32),
+        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=4,
+                      num_warps=32),
+        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=4,
+                      num_warps=32),
+        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=4,
+                      num_warps=32),
+        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4}, num_stages=4,
+                      num_warps=32),
+        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4}, num_stages=4,
+                      num_warps=32),
+    ]
 
 
 def get_cuda_autotune_config():
```
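
For context, the new list is consumed the same way as the CUDA one in this tutorial: a dispatcher picks the list for the active backend and hands it to `@triton.autotune`, which times every `triton.Config` on the first launch for a new `(M, N, K)` and caches the winner. The sketch below illustrates that wiring under the assumption that it matches the surrounding tutorial code, which is not part of this diff; the `is_cuda()` stand-in and the abbreviated `matmul_kernel` signature are illustrative, not verbatim.

```python
import triton
import triton.language as tl


def is_cuda():
    # Tutorial-style backend check (stand-in): True when the active Triton backend is CUDA.
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def get_autotune_config():
    # Pick the config list for the active backend; get_xpu_autotune_config()
    # is the function added in the hunk above.
    return get_cuda_autotune_config() if is_cuda() else get_xpu_autotune_config()


@triton.autotune(
    configs=get_autotune_config(),
    key=['M', 'N', 'K'],  # re-run the sweep whenever the problem shape changes
)
@triton.jit
def matmul_kernel(
        a_ptr, b_ptr, c_ptr, M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        # The autotuner fills these constexpr block sizes from the chosen Config.
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    # Kernel body elided; see the tutorial's full matmul_kernel.
    pass
```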
