[TKW] Propagate GetResult/IterArg of Reductions out for expansion and thread shape #225

Merged
merged 3 commits into from
Oct 21, 2024

raikonenfnu
Contributor

In this PR, we add support for propagating indexing, expansion, and thread shape information from the IterArg/induction variable of a Reduction loop to its users outside the loop.

This is important for cases such as a reduction loop with two induction variables that have different layouts/indices, where we want to apply a binaryOp to them before writing them out. We need to determine whether that requires, for example, a broadcast. That is not possible unless we propagate this information from the IterArg out to the GetResult and its use-def chain.
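
To make the motivation concrete, here is a minimal sketch of the idea, not the PR's implementation. It assumes each value can be summarized by the tuple of dimensions it is indexed over; `propagate_to_results`, `missing_dims`, and the names `acc`, `row_max`, `res_acc`, and `res_max` are all hypothetical and do not come from the codebase.

```python
from typing import Dict, Tuple

Dims = Tuple[str, ...]


def propagate_to_results(
    iter_arg_dims: Dict[str, Dims], result_of: Dict[str, str]
) -> Dict[str, Dims]:
    """Copy each iter arg's dims onto the result that reads it outside the loop.

    `result_of` maps a (hypothetical) GetResult name to its iter arg name.
    """
    return {result: iter_arg_dims[arg] for result, arg in result_of.items()}


def missing_dims(operand: Dims, other: Dims) -> Dims:
    """Dims that `other` has but `operand` lacks, i.e. dims to broadcast along."""
    return tuple(d for d in other if d not in operand)


# Two iter args of one reduction loop with different layouts (assumed shapes).
iter_arg_dims = {"acc": ("M", "N"), "row_max": ("M",)}
result_of = {"res_acc": "acc", "res_max": "row_max"}

# Without this step the results outside the loop carry no layout information,
# so a binary op on them cannot tell whether a broadcast is needed.
result_dims = propagate_to_results(iter_arg_dims, result_of)
to_broadcast = missing_dims(result_dims["res_max"], result_dims["res_acc"])
print(to_broadcast)  # ("N",)
```

In this toy setup, `res_max` has to be broadcast along `N` before it can be combined with `res_acc`, and that decision is only possible once the dims have been carried out of the loop.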

@raikonenfnu raikonenfnu force-pushed the propagateReductionIterArgs branch 3 times, most recently from 4547e72 to 3d6fbba Compare October 20, 2024 01:37
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
Contributor

@harsh-nod harsh-nod left a comment

Just a few small comments, but looking good otherwise!

"""
Get root graph from some child graph inside a nested graph.
Using the assumption that any child/nested graph should have a parent_op,
who we can query for it's owner graph from to go up one level.
Contributor

nit: I understand what this function is doing but would just reword this comment a little to make it more clear.

Contributor Author

sg! :)

Contributor Author

I updated the description, please check again :)

Contributor

minor typo: outer, otherwise looks good thanks!
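
For readers following this thread, the docstring's idea of walking up from a nested graph to the root can be sketched as below. This is an illustration under stated assumptions rather than the code in the PR: the `Graph` and `Op` classes are hypothetical stand-ins, and the only property taken from the docstring is that a nested graph has a `parent_op` whose owner graph is one level up.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(eq=False)
class Op:
    """Hypothetical op that lives in `graph`."""

    name: str
    graph: "Graph"


@dataclass(eq=False)
class Graph:
    """Hypothetical graph; a nested graph records the op that owns it."""

    name: str
    parent_op: Optional[Op] = None


def get_root_graph(graph: Graph) -> Graph:
    """Follow parent_op -> owner graph until a graph with no parent_op is found."""
    while graph.parent_op is not None:
        graph = graph.parent_op.graph
    return graph


# root contains a reduction whose body contains another nested graph.
root = Graph("root")
reduction = Op("reduction", root)
body = Graph("reduction_body", parent_op=reduction)
inner_op = Op("inner_reduction", body)
inner_body = Graph("inner_body", parent_op=inner_op)

print(get_root_graph(inner_body).name)  # root
```

The loop form handles any nesting depth and returns the input unchanged when it is already the root.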

iree/turbine/kernel/ops/wave_ops.py (resolved)
iree/turbine/kernel/wave/thread_shape_analysis.py (outdated, resolved)
elif isinstance(custom, ReduceOp):
    fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op)
    bwd_slice = set()
    if custom.init != None and not isinstance(
Contributor

Why do we need to check whether the init is a ReduceOp?

Contributor Author

Similar reason to chained matmul, where we do not expand on the acc if the acc is also a matmul.

Contributor Author

We can imagine a reduction feeding into another reduction, where we use the 1st reduction as the init of the second. It may try to slice/attach information to all the operands of the 1st reduction, which is no good. :)

Contributor

right so if you didn't have this check, then you would do a bwd_slice on the first reduction, which would overwrite the information from the first reduction.

Contributor Author

yeap!
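
The point agreed on in this thread can be illustrated with a toy graph. This is a sketch under assumptions, not the analysis in `thread_shape_analysis.py`: `Node`, `backward_slice`, and `init_slice_for_reduce` are hypothetical names, and the guard only mirrors the idea of skipping the backward slice when the init is itself a reduction.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Set


@dataclass(eq=False)
class Node:
    """Hypothetical graph node; `kind` is "reduce" for reductions."""

    name: str
    kind: str = "elementwise"
    operands: List["Node"] = field(default_factory=list)


def backward_slice(node: Node) -> Set[str]:
    """Names of `node` and everything it transitively depends on."""
    seen: Set[str] = set()
    stack = [node]
    while stack:
        cur = stack.pop()
        if cur.name in seen:
            continue
        seen.add(cur.name)
        stack.extend(cur.operands)
    return seen


def init_slice_for_reduce(init: Optional[Node]) -> Set[str]:
    """Backward slice of a reduction's init, skipped if init is a reduction."""
    if init is None or init.kind == "reduce":
        # The first reduction and its operands keep their own information.
        return set()
    return backward_slice(init)


a = Node("a")
first_reduce = Node("first_reduce", kind="reduce", operands=[a])

# With the guard, nothing from the first reduction is touched.
print(init_slice_for_reduce(first_reduce))  # set()
# Without it, the slice would reach the first reduction and its operands.
print(sorted(backward_slice(first_reduce)))  # ['a', 'first_reduce']
```

Anything in the unguarded slice would have its information overwritten by the second reduction, which is the failure mode described above.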

@@ -231,6 +231,9 @@ def _trace_and_get_kernel_signature(
# Set indices.
set_node_indices(graph, self.constraints)

# Analyze Thread Shapes per Op.
determine_thread_shapes(graph)
Contributor

nice! :)

Contributor Author

thanks! :)

Contributor

Actually, thinking about this more, this could cause some issues with the symbol renaming I am doing. Does thread shape analysis really get easier by moving it here, or can we keep it post-expansion, after I remove all the renaming?

Contributor

What I am thinking currently is rename, set index, expand, set post expand index, remove renames and then rest of pipeline.

Contributor

We can leave thread shape analysis here as well, but then it needs to handle renames and essentially remap variables based on the dicts in renames.

Contributor Author

yeah, I think this warrants a lower latency communication on VC haha

iree/turbine/kernel/wave/utils.py (resolved)
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
Contributor

@harsh-nod harsh-nod left a comment

thanks, lgtm!

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
@raikonenfnu raikonenfnu merged commit 10cf8c7 into iree-org:main Oct 21, 2024
7 of 8 checks passed