-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TKW] Add support for multiple/local reduceOp (#234)
In order to support flash attention, we'd need to be able to expand ReduceOps in the reduction dimension as well. We will do this by expanding the source of ReduceOp and locally reduce all of them. In that effort, we introduce this PR(1st out of 2) that add support of locally reducing over multiple variables. The second PR on the way would be expansion of ReduceOp. In this PR we are contributing two things: 1. Checks for consistency of indexing_dims, types, thread_shapes for multiple sources of ReduceOp 2. Modify emit of local reduction to generate iteratively slice and reduce over multiple arguments/srcs. --------- Signed-off-by: Stanley Winata <stanley.winata@amd.com>
- Loading branch information
1 parent
00dcee7
commit 50e17a5
Showing
4 changed files
with
130 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters