[GEMM-perf] matmul is slower when one input needs to be transposed #1795
Comments
@mgrabban thanks for the feedback. Could you please provide information on your runtime environment:
I am doing this on PVC (Intel GPU Max 1550).
My Agama version is 950.4
I am using the PyTorch/IPEX installed using the script inside
I am using oneAPI/2024.2.0
Could you please retest with the
To build upstream PyTorch from source, run the following script.
Our Tutorials code still has
@vlad-penkin the pytorch-ipex installation script keeps changing. Today the install itself fails. I tried
Are you able to run matmul/triton benchmark.py on your end?
The installation issue is now fixed, but timing is now broken, so the Triton perf time is showing as 0.0. I think this is the reason
@mgrabban thanks for the update! See my notes below:
@vlad-penkin I'm now able to run and get perf data as shown below
As you can see, the issue is not resolved: inference involving
@mgrabban, what are the sizes of the matrices you are using? I could not run triton_inf or triton_train as they were not shared. However, I tried running the matmul kernel in the Triton tutorials with and without transposing both inputs a and b for various matrix sizes. I used this code to launch my kernel; it is just a slightly modified version of your code, except I do a plain multiply instead of a fused_mul_add.
And below are my results for different matrix sizes:
I also tried modifying the kernel in the tutorial to a fused multiply-add and still get similar numbers; I don't see a performance degradation when one of the inputs is transposed, instead I still see a slight performance increase. Could you recheck that you are using the latest Agama drivers and PyTorch from upstream, run the kernel in this tutorial with the launch script, and let me know whether the performance degradation still exists when running the matrix multiplication alone? There might be other functionality in triton_inf or triton_train that has an unexpected effect. These are my HW details: LIBIGC1_VERSION=1.0.17193.16-950
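The launch code referred to above is not reproduced in the thread. As a rough illustration only, a harness along these lines could time the transposed and non-transposed cases with `triton.testing.do_bench`; the `matmul` wrapper and the `"xpu"` device are assumptions here (a tutorial-style wrapper is sketched near the end of this thread), not code from this issue:

```
import torch
import triton

# Rough sketch only. `matmul` is assumed to be a tutorial-style Triton matmul
# wrapper; device="xpu" assumes an Intel GPU build of PyTorch (use "cuda" on
# NVIDIA hardware).
def bench_case(M, N, K, transpose_b, device="xpu", dtype=torch.float16):
    a = torch.randn((M, K), device=device, dtype=dtype)
    # For the transposed case, store B as [N, K] and pass a transposed view,
    # so only the strides differ from the plain case.
    b = torch.randn((N, K), device=device, dtype=dtype).t() if transpose_b \
        else torch.randn((K, N), device=device, dtype=dtype)

    torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
    triton_ms = triton.testing.do_bench(lambda: matmul(a, b))
    tag = "A @ B.T" if transpose_b else "A @ B  "
    print(f"{tag} M={M} N={N} K={K}: torch {torch_ms:.3f} ms, triton {triton_ms:.3f} ms")

for size in (1024, 2048, 4096):
    bench_case(size, size, size, transpose_b=False)
    bench_case(size, size, size, transpose_b=True)
```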
@mgrabban could you provide us with the cached Triton-generated code for both runs (with and without transpose)? The easiest way to do it is to delete your Triton cache (
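A minimal sketch of clearing that cache before rerunning; the default location `~/.triton/cache` and the `TRITON_CACHE_DIR` override are assumptions here, not taken from this thread:

```
import os
import shutil

# Sketch: remove the Triton kernel cache so the next run regenerates and
# re-caches the compiled code. ~/.triton/cache is the assumed default
# location; TRITON_CACHE_DIR overrides it if set.
cache_dir = os.environ.get("TRITON_CACHE_DIR", os.path.expanduser("~/.triton/cache"))
if os.path.isdir(cache_dir):
    shutil.rmtree(cache_dir)
print(f"Removed Triton cache at {cache_dir}; rerun the benchmark to repopulate it.")
```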
Just for reference: a single-file reproducer was provided to @alexbaden
Hi,
Just for reference, here is the reproducer mentioned above and its output. Output:
Reproducer:
Based on this feedback #2408 (review), changed the GEMM benchmark to include the transposed-matrices cases. Closes #2424. Relates to #1795. The A@B^t case is important because the weight matrix is often stored in [M, K] format, for example in https://pytorch.org/docs/stable/generated/torch.nn.Linear.html. Right now we are about 1.5 times slower on XPU than raw torch for that case. The A^t@B case is important because it is part of matmul backprop (see the short sketch below). Right now we are about 4 times slower on XPU than raw torch for that case.
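As a brief aside on where these shapes come from (an illustrative sketch, not part of the benchmark code): `nn.Linear` stores its weight as `[out_features, in_features]`, so the forward pass is an A@B^t product, and the weight gradient in backprop is an A^t@B-shaped product.

```
import torch

# Illustrative shapes only; names here are not from the benchmark code.
batch, in_features, out_features = 32, 1024, 4096
x = torch.randn(batch, in_features)
w = torch.randn(out_features, in_features)   # nn.Linear weight layout: [out, in]

y = x @ w.t()                 # forward pass: the A @ B^t case
grad_y = torch.randn_like(y)  # gradient flowing back from the loss

grad_w = grad_y.t() @ x       # weight gradient: the A^t @ B case
grad_x = grad_y @ w           # input gradient: a plain A @ B
```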
Hello @vlad-penkin, here are my findings:
Now (using the latest in the main branch), we have
So while A x B.T performance has improved (from 3.8X slower to 1.7X slower vs PyTorch), A.T x B performance has deteriorated (from 4.2X slower to 8.9X slower vs PyTorch). (I see training perf for the matmul kernel in llmbench has deteriorated overall vs PyTorch.) What do you think? Do you plan to pursue this further?
We cannot lower a transposed A matrix to a transposed 2D block load. Instead, the load is lowered via the LLVM path introduced in #2181. There appears to be a performance regression in this path, which is slower than materializing the block in SLM, reading it into registers, and computing the dot product from there. Using the work in #2420 I am able to drop the block load attribute for this case and go down the non-block-ptr path.

Performance on main:

```
Compute A x B
✅ Triton and Torch match
Time for torch: 0.32444801926612854 ms
Time for triton: 0.44371041655540466 ms
Compute A x B.T
✅ Triton and Torch match
Time for torch: 0.32708799839019775 ms
Time for triton: 0.634996771812439 ms
Compute A.T x B
✅ Triton and Torch match
Time for torch: 0.31204161047935486 ms
Time for triton: 3.4140689373016357 ms
Compute A.T x B.T
✅ Triton and Torch match
Time for torch: 0.45701122283935547 ms
Time for triton: 3.7463345527648926 ms
```

Performance on this PR:

```
Compute A x B
✅ Triton and Torch match
Time for torch: 0.3081200122833252 ms
Time for triton: 0.44333598017692566 ms
Compute A x B.T
✅ Triton and Torch match
Time for torch: 0.33799198269844055 ms
Time for triton: 0.6391856074333191 ms
Compute A.T x B
✅ Triton and Torch match
Time for torch: 0.31700319051742554 ms
Time for triton: 1.5733630657196045 ms
Compute A.T x B.T
✅ Triton and Torch match
Time for torch: 0.45083683729171753 ms
Time for triton: 1.8271965980529785 ms
```

Note that the important commit is `31386ef1132c3f6cf9cb5f1063ecfab705f4c2a1`. Once #2420 is merged I will rebase this. Depends on #2420. Links to #1795.
I find that `matmul(X, Y)` is ~4X slower when either X or Y needs to be transposed. So I have a matmul kernel that is similar to the one in the Triton tutorial here. That kernel is launched from this code. Note that the strides of X or Y are switched (e.g. `Xstride0, Xstride1 = X.stride(1), X.stride(0)`) if it needs to be transposed. I notice if neither needs to be transposed, performance is similar to PyTorch's matmul perf, but when either needs to be transposed (so that strides are switched for that input), performance is 4X slower.
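For reference, here is a minimal sketch of the setup described above; it is not the author's actual kernel or launch code. It shows one plain (non-autotuned) Triton matmul kernel plus a wrapper that forwards each input's strides, which is all that switching the strides amounts to for a transposed view. It assumes fp16 inputs on an XPU or CUDA device and M, N, K divisible by the block sizes.

```
import torch
import triton
import triton.language as tl

# Sketch only: a plain (non-autotuned) Triton matmul kernel. Assumes fp16
# inputs and M, N, K divisible by the block sizes.
@triton.jit
def matmul_kernel(
    x_ptr, y_ptr, out_ptr, M, N, K,
    stride_xm, stride_xk, stride_yk, stride_yn, stride_om, stride_on,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # Pointers are built from the strides, so a transposed view "just works".
    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
    y_ptrs = y_ptr + offs_k[:, None] * stride_yk + offs_n[None, :] * stride_yn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        x = tl.load(x_ptrs)
        y = tl.load(y_ptrs)
        acc += tl.dot(x, y)
        x_ptrs += BLOCK_K * stride_xk
        y_ptrs += BLOCK_K * stride_yk
    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    tl.store(out_ptrs, acc.to(tl.float16))

def matmul(X, Y):
    # X and Y may be transposed views; passing X.stride()/Y.stride() through
    # unchanged is exactly the "switched strides" described in this issue.
    M, K = X.shape
    K2, N = Y.shape
    assert K == K2
    out = torch.empty((M, N), device=X.device, dtype=torch.float16)
    grid = (triton.cdiv(M, 128), triton.cdiv(N, 128))
    matmul_kernel[grid](
        X, Y, out, M, N, K,
        X.stride(0), X.stride(1), Y.stride(0), Y.stride(1),
        out.stride(0), out.stride(1),
        BLOCK_M=128, BLOCK_N=128, BLOCK_K=32,
    )
    return out
```

With such a wrapper, `matmul(X, Y)`, `matmul(X.t(), Y)`, and `matmul(X, Y.t())` all run the same kernel; only the stride arguments differ, which is what exposes the slowdown reported here.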
This does not happen on CUDA devices. So can you please look into making it efficient for XPU devices as well?