[TKW] Teach expansion to handle non direct acc and ReduceOp on reduction dim. #243
Conversation
Overall, looks good! Thanks, I like the changes to expansion. Just some comments about documentation.
Force-pushed from d9e0b24 to c7f5ac3. (Commits signed off by: Stanley Winata <stanley.winata@amd.com>)
Fix the equality check: a manual comparison of shape and type is required, since the nodes sometimes do not compare equal without it. Signed-off-by: Stanley Winata <stanley.winata@amd.com>
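As a rough illustration of the kind of check this commit describes (hypothetical helper and attribute names, not the actual expansion-pass code):

```python
def nodes_equivalent(a, b) -> bool:
    # Operator identity alone is not enough to treat two expanded
    # nodes as interchangeable; their shapes and element types must
    # also match.
    return a.op == b.op and a.shape == b.shape and a.dtype == b.dtype
```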
Overall looks good, and we can address the open questions later. Thanks!
In flash attention, we need to enable non-direct-acc matmul, as well as expansion of ReduceOp along the reduction dimension. The former is needed in FA because we apply scaling to the acc of the second MMA before feeding it in, so the accumulator is no longer the direct output of another MMA. The latter is required because ReduceOp/MaxOp is in the backward slice of the second MMA's LHS, which means it must be expanded in the K2/reduction dimension as well.
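For context, here is a minimal schematic of the flash-attention online-softmax loop in plain NumPy (illustrative pseudocode only, not the TKW DSL; all names are invented for this sketch). It shows both patterns: the accumulator is rescaled before the second matmul (the non-direct-acc case), and the row-max reduction feeds the LHS of that second matmul (the ReduceOp-on-reduction-dim case):

```python
import numpy as np

def flash_attention_tile(q, k_tiles, v_tiles):
    """Online-softmax loop over K/V tiles; purely illustrative."""
    m = np.full(q.shape[0], -np.inf)          # running row max
    l = np.zeros(q.shape[0])                  # running row sum
    acc = np.zeros((q.shape[0], v_tiles[0].shape[1]))
    for k, v in zip(k_tiles, v_tiles):
        s = q @ k.T                           # first MMA
        m_new = np.maximum(m, s.max(axis=1))  # max ReduceOp in the backward
                                              #   slice of the second MMA's LHS
        p = np.exp(s - m_new[:, None])        # second MMA's LHS
        scale = np.exp(m - m_new)
        l = l * scale + p.sum(axis=1)
        acc = acc * scale[:, None]            # acc is scaled before use:
        acc = p @ v + acc                     #   non-direct-acc matmul
        m = m_new
    return acc / l[:, None]
```

The `acc * scale` multiply is exactly why expansion can no longer assume the second MMA's accumulator comes directly from another MMA node, and the row max is what forces ReduceOp to be expanded along the K2/reduction dimension.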