[TKW] Propagate GetResult/IterArg of Reductions out for expansion and thread shape #225
Conversation
Force-pushed from 4547e72 to 3d6fbba
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
Just a few small comments, but looking good otherwise!
iree/turbine/kernel/ops/wave_ops.py
Outdated
""" | ||
Get root graph from some child graph inside a nested graph. | ||
Using the assumption that any child/nested graph should have a parent_op, | ||
who we can query for it's owner graph from to go up one level. |
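For context, a minimal sketch of the upward walk the docstring describes, assuming fx-style graphs where each nested graph records its enclosing node as parent_op (attribute names mirror the docstring, not necessarily the final API):

# Sketch only: the upward walk described by the docstring above.
# Assumes each nested graph records its enclosing node as `parent_op`,
# and every node knows its owning graph via `.graph`.
def get_root_graph(graph):
    while getattr(graph, "parent_op", None) is not None:
        # The parent op lives in the graph one level up; hop to it.
        graph = graph.parent_op.graph
    return graph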
nit: I understand what this function is doing but would just reword this comment a little to make it more clear.
sg! :)
I updated the description, please check again :)
minor typo: outer, otherwise looks good thanks!
elif isinstance(custom, ReduceOp):
    fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op)
    bwd_slice = set()
    if custom.init != None and not isinstance(
Why do we need to check whether the init is a ReduceOp?
Similar reason as in chained matmul: we do not expand on the acc if the acc is also a matmul.
We can imagine it's a reduction into another reduction, where we use the 1st reduction as an init to the second. It may try to slice/attach information to all the operands of the 1st reduction, which is no good. :)
Right, so if you didn't have this check, you would do a bwd_slice on the first reduction, which would overwrite the information from the first reduction.
yeap!
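To make the failure mode concrete, here is a hedged sketch of the guard being discussed; the helper names mirror the diff above, but the completion of the truncated isinstance check and the wrapping function are illustrative, not the verbatim PR code:

# Illustrative sketch, not the exact code in this PR.
def collect_slices_for_reduce(custom, propagatable_op):
    fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op)
    bwd_slice = set()
    # Chained reductions: if reduction B's init is itself a ReduceOp
    # (reduction A), slicing backward from the init would walk into all
    # of A's operands and clobber the thread-shape info A already set.
    if custom.init is not None and not isinstance(
        get_custom(custom.init), ReduceOp
    ):
        bwd_slice = capture_backward_slice(custom.init, propagatable_op)
    return fwd_slice | bwd_slice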
@@ -231,6 +231,9 @@ def _trace_and_get_kernel_signature(
     # Set indices.
     set_node_indices(graph, self.constraints)
+
+    # Analyze Thread Shapes per Op.
+    determine_thread_shapes(graph)
nice! :)
thanks! :)
Actually, thinking about this more, this could cause some issues with the symbol renaming I am doing. Just wondering: does thread shape analysis really get easier by moving it here, or can we keep it post-expansion after I remove all the renaming?
What I am thinking currently is: rename, set index, expand, set post-expansion index, remove renames, and then the rest of the pipeline.
We can leave thread shape analysis here as well, but then it needs to handle renames and essentially remap variables based on the dicts in renames.
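In pass-pipeline terms, the two options discussed above would look roughly like this (pass names other than set_node_indices and determine_thread_shapes are illustrative stand-ins for the renaming work, not a settled API):

# Option A: run thread shape analysis after renames are removed.
rename_symbols(graph)                           # hypothetical renaming pass
set_node_indices(graph, constraints)
expand_graph(graph, constraints)
set_post_expansion_indices(graph, constraints)  # hypothetical
remove_symbol_renames(graph)                    # hypothetical
determine_thread_shapes(graph)
# Option B: keep determine_thread_shapes right after set_node_indices,
# but teach it to remap variables via the rename dicts.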
yeah, I think this warrants a lower latency communication on VC haha
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
thanks, lgtm!
In this PR, we add support for propagating indexing, expansion, and thread shape information from the IterArg/induction variable of a Reduction loop to its users outside the loop.
This is important to enable, for example, having two induction variables in a reduction loop with different layouts/indices on which we want to do a binaryOp before writing them out. We need to figure out whether, for example, that binaryOp requires a broadcast, which is not possible unless we propagate this information from the IterArg to the GetResult outside the loop and along its use-def chain.
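As a hedged sketch of the motivating case (shapes, symbols, and op signatures here are illustrative, loosely following the tkw kernel style rather than an actual test in this PR):

# Two iter args with potentially different layouts flow out of the
# reduction; the add on their GetResults can only decide whether a
# broadcast is needed if index/thread-shape info is propagated out.
@tkw.reduction(K, init_args=[acc_mma, acc_sum])
def repeat(acc_mma, acc_sum):
    lhs = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
    rhs = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
    acc_mma = tkw.mma(lhs, rhs, acc_mma)    # MMA-layout accumulator
    acc_sum = tkw.sum(lhs, acc_sum, dim=K)  # reduction-layout accumulator
    return acc_mma, acc_sum

res_mma, res_sum = repeat  # GetResult of each IterArg outside the loop
tkw.write(res_mma + res_sum, c, elements_per_thread=STORE_ELEMS_PER_THREAD)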