Our current index_seq_analysis does a backward pass on `lhs`, `rhs`, and `acc`, and then a forward pass on its consumers. This works for now, but for more complex cases we may need to extend it to also take detours. Consider this case:

In the case above, if we want the `read` of `bias` to also pick up the layouts from the mma's `acc`, we'd need to take a detour during layout setting, i.e. `mma -> res -(detour)-> bias read`.

This is also evident in our attention kernel: currently we are manually setting `vector_shapes` for `M` and `N` on our attention kernel.
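A minimal sketch of the problem (hypothetical names, not the actual iree-turbine API): `res` consumes both `mma` and `bias_read`, so a pure forward pass from `mma` reaches `res` but never `bias_read`; only a detour from `res` back through its other operands does.

```python
# Tiny dataflow graph: res = f(mma, bias_read). Only mma starts with a layout.
layouts = {"mma": "acc_layout"}
consumers = {"mma": ["res"]}          # forward edges (producer -> consumers)
operands = {"res": ["mma", "bias_read"]}  # backward edges (node -> operands)

# Forward pass: push layouts from producers to their consumers.
for node, users in consumers.items():
    for user in users:
        layouts.setdefault(user, layouts[node])

# The forward pass alone never visits bias_read (it is not a consumer of mma).
assert "bias_read" not in layouts

# Detour: from each node that just received a layout, walk back to its
# *other* operands, i.e. mma -> res -(detour)-> bias_read.
for node in list(layouts):
    for op in operands.get(node, []):
        layouts.setdefault(op, layouts[node])

assert layouts["bias_read"] == "acc_layout"
```

The detour step is what the current backward-then-forward traversal lacks: once a consumer is resolved, its unresolved sibling operands need a second backward hop.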
See iree-turbine/tests/kernel/wave/wave_attention_test.py, line 244 at commit 2b45c0f.
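To illustrate what detour-based propagation would buy us here (hypothetical names and shapes, not the real wave constraint API): the `M`/`N` entries of `vector_shapes` that the attention test hardcodes today could instead be inferred from the layout recovered from the mma's `acc`.

```python
# What the attention kernel does today: pin vector_shapes by hand.
manual_vector_shapes = {"M": 16, "N": 16}  # hypothetical sizes

# Layout that detour propagation would recover from mma's acc (illustrative).
acc_layout = {"M": 16, "N": 16}

def infer_vector_shapes(acc_layout):
    # Derive vector_shapes for exactly the dims the acc layout already fixes,
    # removing the need to specify them manually.
    return dict(acc_layout)

assert infer_vector_shapes(acc_layout) == manual_vector_shapes
```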