[DPAS]: Use 2d-loads instruction to load the operand of tt.dot #146

Closed
etiotto opened this issue Dec 18, 2023 · 5 comments · Fixed by #959 or #958

Comments

@etiotto
Contributor

etiotto commented Dec 18, 2023

The operands of Triton's tt.dot operation should be loaded with the specialized instructions that load 2D blocks of a matrix.
Loading the operands in blocks is more efficient than loading them with regular loads (@llvm.genx.GenISA.LSCPrefetch).

We might need to leverage the semantic information associated with Triton's block pointers (https://triton-lang.org/main/getting-started/tutorials/08-experimental-block-pointer.html) in order to generate 2D block loads.

@tdeng5

tdeng5 commented Jan 23, 2024

Liyang, please check how CUDA handles stride_xx.

@LiyangLingIntel
Contributor

Liyang, please check how CUDA handles stride_xx.

The Triton CUDA pipeline lowers 2D loads to TMALoadTiledOp, which has no restriction on the last-dimension stride.
For the Triton XPU pipeline, leveraging GenISA_LSC2DBlockRead requires the last dimension of each block pointer to be contiguous (stride[-1] = 1).
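
To make the stride[-1] = 1 condition concrete, here is a minimal pure-Python sketch. The helper names `row_major_strides` and `last_dim_is_contiguous` are hypothetical and are not part of Triton or the XPU backend; strides are counted in elements, as in the block pointer API.

```python
def row_major_strides(shape):
    """Element strides of a dense row-major array with the given shape."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def last_dim_is_contiguous(strides):
    """True when neighbouring elements of the last dimension are adjacent in memory."""
    return strides[-1] == 1


# A dense 64x32 row-major matrix: the last dimension is contiguous.
a_strides = row_major_strides((64, 32))

# A transposed view of the same buffer swaps the strides, so the last
# dimension is no longer contiguous and would not meet the condition.
a_t_strides = tuple(reversed(a_strides))
```

Under these assumptions, `a_strides` passes the check and `a_t_strides` does not, which is the case where the lowering would have to use regular loads.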

For the first stage, in the 2D load conversion lowering pass, we will check the type of the stride attribute. If it is a constant and meets the requirements for a 2D block load, we can leverage genx.matrix.2Dblockload. Otherwise, the lowering falls back to regular loads. This is the fastest way to enable 2DblockLoad in our pipeline, with limited functionality.

For the second stage, we will consider the dynamic stride case. If the stride attribute is a variable, the lowering strategy is to use a conditional branch that checks at kernel runtime whether the last-dimension stride is 1, and then picks block loads or regular loads accordingly.
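
The two stages can be sketched as a small decision function. This is only an illustration in plain Python, not the actual pass: `LoadKind`, `select_lowering`, `resolve_at_runtime`, and the `dynamic_enabled` flag are hypothetical names, and only the last-dimension stride check is modeled (any other requirements of the 2D block load are left out).

```python
from enum import Enum
from typing import Optional


class LoadKind(Enum):
    BLOCK_2D = "2d block load"
    REGULAR = "regular load"
    RUNTIME_CHECK = "branch on stride at kernel runtime"


def select_lowering(last_stride: Optional[int], dynamic_enabled: bool = False) -> LoadKind:
    """Pick a lowering at compile time.

    last_stride is the last-dimension stride when it is a compile-time
    constant, or None when it is only known at kernel runtime.
    """
    if last_stride is not None:
        # Stage 1: constant stride, decided statically.
        return LoadKind.BLOCK_2D if last_stride == 1 else LoadKind.REGULAR
    if dynamic_enabled:
        # Stage 2: emit a conditional branch that is resolved at runtime.
        return LoadKind.RUNTIME_CHECK
    # Stage 1 only: dynamic strides fall back to regular loads.
    return LoadKind.REGULAR


def resolve_at_runtime(runtime_last_stride: int) -> LoadKind:
    """Model of the branch taken inside the kernel for the dynamic case."""
    return LoadKind.BLOCK_2D if runtime_last_stride == 1 else LoadKind.REGULAR
```

With stage 1 alone, `select_lowering(None)` yields a regular load; once stage 2 is enabled, the same input yields a runtime check whose outcome depends on the stride value seen by the kernel.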

@vlad-penkin
Contributor

@LiyangLingIntel as per our discussion could you please split this ticket by stage

@LiyangLingIntel
Contributor

@LiyangLingIntel as per our discussion could you please split this ticket by stage

Sure, I have added two issues (#413 and #415) to split this ticket into two stages.

etiotto self-assigned this Apr 29, 2024
@etiotto
Contributor Author

etiotto commented Apr 29, 2024

Helping with refactoring and code review.
