Tutorial example 03 performance issue #1122
Comments
@narendrachaudhary51 could you please provide information on your environment?
I was able to run the following command:
@narendrachaudhary51 I'm playing with this example too and have noticed the same performance issue. I think the grid search parameters have not been tuned: the grid search parameters currently used are the same as for CUDA, but I noticed that using higher
with those parameters, the 512 * 512 matmul performance is very close to
Here are the changes to the example that I find give better performance: #1139. It also improves performance for higher dimensions, with a 3 to 4 times speedup, but not to the point of reaching
Maybe more work on the grid search could help achieve a good speedup on higher dimensions too.
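As an illustration of what retuning that grid looks like in the tutorial, here is a minimal sketch; the block sizes, num_warps, and num_stages values below are assumed placeholders, not the tuned configurations from #1139:

```python
import triton

# Hypothetical autotune configurations for the tutorial's matmul kernel.
# The point of the discussion above is widening the search space
# (in particular trying larger num_warps) rather than reusing the CUDA defaults.
xpu_configs = [
    triton.Config(
        {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 4},
        num_stages=2, num_warps=32,
    ),
    triton.Config(
        {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 4},
        num_stages=2, num_warps=16,
    ),
]

# In the tutorial these would be applied to the kernel via
# @triton.autotune(configs=xpu_configs, key=["M", "N", "K"]).
```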
@fcharras Thank you for your reply. I suspected that the grid search parameters could be the cause.
I have only some intuition and a limited understanding of which XPU concepts warps (a CUDA-only concept) would map to (execution units? subslices?), but I suspected that the "num_warps" value in the grid is too low in this regard. To be honest, though, I'm surprised that setting it so low is good for CUDA devices to begin with, so it was more luck than anything...
From (again) an entry-level understanding of it, matmul performance is a balance that is hard to achieve between global and local memory bandwidth, cache bandwidth and hit rate, and actual compute. I thought that if something has to be a bottleneck there, it might be the compute, because of too few threads being leveraged, which is (I think) increased when increasing
By default, the Triton kernel in the tutorial compiles the matrix multiplication operation to a sequence of scalar floating-point multiply-add instructions. We can force the tutorial to use 16 threads per warp, which allows our compiler to generate specialized HW instructions (DPAS) rather than scalar FMAs. Performance then improves by ~7x for large problem sizes. The code to change in the tutorial is:
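The exact diff is not preserved in this thread; as a rough sketch, assuming the Intel XPU backend accepts a threads_per_warp option at launch time (that option name and its placement are assumptions here), the change inside the tutorial's matmul() wrapper would look something like:

```python
# Hedged sketch: request 16 threads per warp (sub-group size 16) when launching the
# tutorial's matmul kernel, so the compiler can map the dot product onto DPAS
# instructions instead of scalar FMAs. Variable names reuse those already defined
# in the tutorial's matmul() wrapper.
matmul_kernel[grid](
    a, b, c,
    M, N, K,
    a.stride(0), a.stride(1),
    b.stride(0), b.stride(1),
    c.stride(0), c.stride(1),
    threads_per_warp=16,
)
```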
Performance is still lagging behind, though. I will post a PR to change the warp size used by the tutorial.
Thanks @etiotto, with this tip, and also the suggestion in #1139 (comment), I'm seeing almost equivalent walltime on the 512 * 512 example, and performance has increased a lot on the larger (4096 * 4096) example, albeit still below
Maybe the matmul with the experimental block pointer approach (tutorial 09) will give better results?
That's an interesting question. Note, however, that the experimental examples, including tutorial 09 on the use of block pointers for matrix-matrix multiplication, have been removed from the upstream repo (triton-lang/triton#3371), but I am not sure why. Still, I would be interested in the performance results on a Max Series GPU with a tuned grid and optimal
I tried the change suggested by @etiotto and increased the number of warps. This gave me a speedup across the board. I tried several other configurations, but I was not able to go beyond this. Is there a way to check and analyze the Triton-generated code? We could then identify the inefficiencies in the code by comparing it to the torch implementation.
The rocBLAS legend actually refers to Intel oneDNN's kernel wrapped as the PyTorch XPU matmul, right?
@ogrisel Yes. It is the PyTorch XPU matmul. It must be using the oneDNN kernel underneath.
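In other words, the baseline curve is measured with something along these lines (a sketch; it assumes a PyTorch build where the xpu device is available):

```python
import torch

# The "rocBLAS" legend in the tutorial output is simply torch.matmul running on the
# XPU device, which dispatches to a oneDNN kernel under the hood.
a = torch.randn((512, 512), device="xpu", dtype=torch.float16)
b = torch.randn((512, 512), device="xpu", dtype=torch.float16)
torch_output = torch.matmul(a, b)
```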
Is the Triton-generated code using 2D loads when doing the following operations? `a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)` How can I dump the generated code?
Output of the different stages can be found in
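For example, one way to inspect what was generated is to look at the Triton kernel cache; this is a small sketch assuming the default cache location (~/.triton/cache, overridable with the TRITON_CACHE_DIR environment variable), where the per-kernel intermediate representations and final binary are written:

```python
import os

# List the compilation artifacts Triton left in its kernel cache
# (e.g. .ttir, .ttgir, .llir and the final binary, depending on the backend).
cache_dir = os.environ.get("TRITON_CACHE_DIR", os.path.expanduser("~/.triton/cache"))
for root, _dirs, files in os.walk(cache_dir):
    for name in files:
        if name.endswith((".ttir", ".ttgir", ".llir", ".spv", ".ptx")):
            print(os.path.join(root, name))
```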
@ogrisel Block pointer-based matrix multiplication is faster than the previous implementation, which only reached a peak performance of 25 TFLOPS. But the current performance is still below the oneDNN matrix multiplication.
@AshburnLee could you please retest the issue and report back the results with the top of the llvm-target branch with the
I'm working on it.
@vlad-penkin I get the following output with the latest nightly build of PyTorch. I commented out the intel_extension_for_pytorch import.
WARNING:root:Wall time is used instead of elapsed_time (not supported). The timing measurements could be innacurate.
Segmentation fault (core dumped)
Same issue on my machine.
After a rebase and another try, here is the result:
Note:
I just expanded
@narendrachaudhary51 can we close this issue?
This is good. However, performance still seems lower compared to oneDNN. I will try to replicate this on my side. Do I still need to patch the upstream PyTorch and compile it, or is it possible to use the latest nightly PyTorch?
For the time being, please use our PyTorch nightly wheels, which are built with 'XPUEvent elapsed_time(...)' support. Our 'XPUEvent elapsed_time(...)' patch, inherited from IPEX, works well with Triton but causes significant performance degradation in PT use cases. We expect that the nightly PyTorch wheels will have 'XPUEvent elapsed_time(...)' enabled sometime in November 2024, following the release of oneAPI 2025.0.0.
I see the following performance with tutorial 03-matrix-multiplication.py:
✅ Triton and Torch match
And with tutorial 10-experimental-block-pointer.py:
matmul-performance-fp16:
My configuration is as follows:
Device Driver -
oneAPI - 2024.2.1
@etiotto I can see that Triton is getting close to the performance of oneDNN. However, it seems both the oneDNN and Triton numbers are slower compared to yours.
The numbers reported in #1122 (comment) are low for a PVC 1550. I just tried tutorial 10 again on a PVC 1100, and the throughput for both Triton and oneDNN is higher:
LIBIGC1_VERSION=1.0.17193.16-950
Hi,
I am trying to run the benchmark in the Python tutorial 03-matrix-multiplication.py.
I do not see the expected performance with Triton. Even for larger square matrix sizes, the performance of Triton does not improve.
I am using the following XPU hardware:
Intel(R) Level-Zero, Intel(R) Data Center GPU Max 1100 1.3 [1.3.26918]
matmul-performance-fp16 (throughput in TFLOPS):
M N K rocBLAS Triton
0 256.0 256.0 256.0 0.108552 0.065673
1 384.0 384.0 384.0 0.363976 0.154921
2 512.0 512.0 512.0 0.857502 0.162093
3 640.0 640.0 640.0 1.664666 0.174643
4 768.0 768.0 768.0 2.829421 0.307113
5 896.0 896.0 896.0 4.453225 0.395680
6 1024.0 1024.0 1024.0 6.550690 0.316575
7 1152.0 1152.0 1152.0 9.213149 0.378349
8 1280.0 1280.0 1280.0 12.485583 0.444944
9 1408.0 1408.0 1408.0 15.589347 0.518402
10 1536.0 1536.0 1536.0 19.637789 0.576684
11 1664.0 1664.0 1664.0 23.102231 0.644403
12 1792.0 1792.0 1792.0 29.274080 0.726035
13 1920.0 1920.0 1920.0 32.775947 0.721385
14 2048.0 2048.0 2048.0 36.019792 0.722632
15 2176.0 2176.0 2176.0 43.476061 0.755484
16 2304.0 2304.0 2304.0 50.967526 0.797482
17 2432.0 2432.0 2432.0 59.675966 0.516829
18 2560.0 2560.0 2560.0 65.795927 0.524481
19 2688.0 2688.0 2688.0 72.057158 0.525320
20 2816.0 2816.0 2816.0 75.685494 0.397117
21 2944.0 2944.0 2944.0 78.605996 0.397091
22 3072.0 3072.0 3072.0 77.401139 0.392022
23 3200.0 3200.0 3200.0 93.879067 0.395438
24 3328.0 3328.0 3328.0 100.308276 0.366734
25 3456.0 3456.0 3456.0 107.904954 0.323059
26 3584.0 3584.0 3584.0 116.285356 0.305686
27 3712.0 3712.0 3712.0 97.801647 0.296061
28 3840.0 3840.0 3840.0 100.229800 0.293752
29 3968.0 3968.0 3968.0 103.656807 0.278301
30 4096.0 4096.0 4096.0 105.869743 0.265378
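For reference on how the numbers in the table above are reported: the tutorial's benchmark converts the measured kernel time into TFLOPS roughly as in the sketch below (based on the upstream tutorial's reporting; higher is better).

```python
# A (M, K) x (K, N) matmul performs 2 * M * N * K floating point operations;
# ms is the measured kernel time in milliseconds.
def tflops(M: int, N: int, K: int, ms: float) -> float:
    return 2 * M * N * K * 1e-12 / (ms * 1e-3)

# Example: ~105 TFLOPS for the 4096 x 4096 x 4096 case corresponds to a kernel
# time of roughly 1.3 ms.
print(tflops(4096, 4096, 4096, 1.3))
```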