[TKW] Propagate GetResult/IterArg of Reductions out for expansion and thread shape #225
```diff
@@ -51,8 +51,9 @@ def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]):
 # Anchor Indicies and Conflict resolution helpers
 #################################################################
 
-anchorOpTypes = (Read, Write, MMA, ReduceOp)
+anchorOpTypes = (Read, Write, MMA, ReduceOp, GetResult)
 noHandleTypes = (Placeholder, Output, ExtractSlice, Allocate)
+legalSubtypes = (IterArg,)
 nonPropagatableTypes = anchorOpTypes + noHandleTypes
 
 
```
```diff
@@ -61,7 +62,10 @@ def is_anchor_op(node: fx.Node):
 
 
 def propagatable_op(node: fx.Node):
-    return not isinstance(get_custom(node), nonPropagatableTypes)
+    custom_node = get_custom(node)
+    return not isinstance(custom_node, nonPropagatableTypes) or isinstance(
+        custom_node, legalSubtypes
+    )
 
 
 def handle_binaryop_conflict(custom_node: CustomOp):
```
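As a standalone illustration of the revised predicate, the sketch below uses toy classes rather than the project's real op hierarchy; the premise that `IterArg` subclasses `Placeholder`, and so would be rejected without the `legalSubtypes` exception, is an assumption inferred from the tuples above.

```python
# Toy sketch (hypothetical classes) of the revised propagatable_op logic:
# IterArg would normally be rejected because it subclasses Placeholder,
# but the legalSubtypes whitelist re-admits it.
class Placeholder: ...
class IterArg(Placeholder): ...  # assumption: IterArg is a Placeholder subtype
class Read: ...

anchorOpTypes = (Read,)
noHandleTypes = (Placeholder,)
nonPropagatableTypes = anchorOpTypes + noHandleTypes
legalSubtypes = (IterArg,)

def propagatable_op(custom_node) -> bool:
    # Reject anchors and no-handle ops, unless the op is a whitelisted subtype.
    return not isinstance(custom_node, nonPropagatableTypes) or isinstance(
        custom_node, legalSubtypes
    )

assert propagatable_op(IterArg())          # whitelisted despite being a Placeholder
assert not propagatable_op(Placeholder())  # still blocked
assert not propagatable_op(Read())         # anchors never propagate through
```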
```diff
@@ -71,7 +75,8 @@ def handle_binaryop_conflict(custom_node: CustomOp):
     lhs_dim_set = set(lhs.type.symbolic_shape)
     rhs_dim_set = set(rhs.type.symbolic_shape)
     if lhs_dim_set == rhs_dim_set:
-        raise ValueError("Cannot broadcast if lhs and rhs is already same.")
+        # Could be caused by consumers(likely also binaryOp) of this node.
+        return []
     if lhs_dim_set.isdisjoint(rhs_dim_set):
         raise ValueError("Cannot broadcast if lhs and rhs has disjointed shapes.")
     # Determine the correct indexSize for binaryOp and insert broadcasting.
```
```diff
@@ -84,7 +89,8 @@ def handle_binaryop_conflict(custom_node: CustomOp):
     propagated_resolutions = capture_forward_slice(broadcast.fx_node, propagatable_op)
     for node in propagated_resolutions:
         get_custom(node).index = dst_op.index
-    return propagated_resolutions
+    resolved_resolutions = capture_backward_slice(broadcast.fx_node, propagatable_op)
+    return propagated_resolutions.union(resolved_resolutions)
 
 
 # Returns True iff all conflicts are handled succesfully.
```
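To see why returning the union of both slices matters, here is a toy illustration: once the broadcast is inserted, every node reachable from it in either direction has a resolved index, so reporting only the forward slice would under-report the resolved region. The dict-based graph and BFS below are hypothetical stand-ins for the real `fx.Graph` and `capture_forward_slice` / `capture_backward_slice`.

```python
from collections import deque

# Hypothetical stand-in graph: a -> bcast -> c
users = {"a": ["bcast"], "bcast": ["c"], "c": []}       # forward edges
producers = {"a": [], "bcast": ["a"], "c": ["bcast"]}   # backward edges

def capture_slice(start: str, edges: dict[str, list[str]]) -> set[str]:
    # Simple BFS standing in for capture_forward_slice / capture_backward_slice.
    seen, work = {start}, deque([start])
    while work:
        for nxt in edges[work.popleft()]:
            if nxt not in seen:
                seen.add(nxt)
                work.append(nxt)
    return seen

fwd = capture_slice("bcast", users)       # {"bcast", "c"}
bwd = capture_slice("bcast", producers)   # {"bcast", "a"}
print(fwd | bwd)                          # the full resolved region: a, bcast, c
```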
```diff
@@ -155,11 +161,26 @@ def determine_thread_shapes(trace: CapturedTrace):
     for anchor_op in anchor_ops:
         custom = get_custom(anchor_op)
         index_sizes = get_custom_dim_sizes(custom)
-        if isinstance(custom, (Read, ReduceOp)):
+        if isinstance(custom, (Read, GetResult)):
             fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op)
             thread_size_to_ops[index_sizes] = thread_size_to_ops.get(
                 index_sizes, set([])
             ).union(fwd_slice)
+        elif isinstance(custom, ReduceOp):
+            fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op)
+            bwd_slice = set()
+            if custom.init != None and not isinstance(
+                get_custom(custom.init), ReduceOp
+            ):
+                bwd_slice = capture_backward_slice(custom.init, propagatable_op)
+            reduce_dims = frozenset(
+                [DimSize(dim, 1) for dim in custom.index.keys() if dim != custom.dim]
+            )
+            thread_size_to_ops[reduce_dims] = (
+                thread_size_to_ops.get(reduce_dims, set([]))
+                .union(fwd_slice)
+                .union(bwd_slice)
+            )
         elif isinstance(custom, Write):
             bwd_slice = capture_backward_slice(custom.fx_node, propagatable_op)
             thread_size_to_ops[index_sizes] = thread_size_to_ops.get(
```

Review discussion on the `init` check:

**Reviewer:** Why do we need to check whether the init is a ReduceOp?

**Author:** For a similar reason as in chained matmul, where we do not expand on the acc if the acc is also a matmul. Imagine a reduction feeding into another reduction, with the first reduction used as the init of the second. The backward slice would try to attach information to all the operands of the first reduction, which is no good. :)

**Reviewer:** Right, so without this check you would take a bwd_slice through the first reduction, which would overwrite the information from the first reduction.

**Author:** Yeap!
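A minimal sketch of the chained-reduction case discussed above, with hypothetical dataclasses standing in for the real ops: when a reduction's init is itself a reduction, the backward slice from that init is skipped so the first reduction's recorded thread shapes are not clobbered.

```python
from dataclasses import dataclass, field

@dataclass
class Register:  # hypothetical non-reduction producer of an init value
    name: str

@dataclass
class ReduceOp:  # simplified stand-in for the real ReduceOp
    name: str
    init: object = None
    operands: list = field(default_factory=list)

reduce1 = ReduceOp("reduce1", init=Register("zero"), operands=["read_a"])
reduce2 = ReduceOp("reduce2", init=reduce1, operands=["read_b"])  # chained

for op in (reduce1, reduce2):
    if op.init is not None and not isinstance(op.init, ReduceOp):
        print(f"{op.name}: backward-slice from init")  # safe: plain init
    else:
        print(f"{op.name}: skip bwd slice so reduce1's shapes stay intact")
```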
In the kernel compilation pipeline, thread shape analysis moves ahead of expansion:

```diff
@@ -231,6 +231,9 @@ def _trace_and_get_kernel_signature(
         # Set indices.
         set_node_indices(graph, self.constraints)
 
+        # Analyze Thread Shapes per Op.
+        determine_thread_shapes(graph)
+
         # Expansion
         expand_graph(graph, self.constraints)
 
```

Review discussion on the new pass position:

**Reviewer:** Nice! :)

**Author:** Thanks! :)

**Reviewer:** Actually, thinking about this more, this could cause some issues with the symbol renaming I am doing. Does thread shape analysis really get easier by moving it here, or can we keep it post-expansion, after I remove all the renaming? What I am currently thinking is: rename, set index, expand, set post-expand index, remove renames, and then the rest of the pipeline. We can leave thread shape analysis here as well, but then it needs to handle renames and essentially remap variables based on the dicts in renames.

**Author:** Yeah, I think this warrants lower-latency communication on VC haha.
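To make the ordering the reviewer sketches above easier to follow, here is a stub-only sketch; every function is a placeholder, and `rename_symbols`, `set_post_expansion_indices`, and `remove_renames` are hypothetical names for the renaming work described in the thread, not passes that exist in the codebase.

```python
# Stub pipeline (hypothetical pass names) mirroring the proposed ordering:
# rename, set index, expand, set post-expand index, remove renames, then
# the rest of the pipeline, including thread shape analysis.
def rename_symbols(graph): print("rename symbols")            # hypothetical
def set_node_indices(graph): print("set indices")
def expand_graph(graph): print("expand")
def set_post_expansion_indices(graph): print("set post-expand indices")  # hypothetical
def remove_renames(graph): print("remove renames")            # hypothetical
def determine_thread_shapes(graph): print("thread shapes")

def proposed_pipeline(graph):
    rename_symbols(graph)
    set_node_indices(graph)
    expand_graph(graph)
    set_post_expansion_indices(graph)
    remove_renames(graph)
    determine_thread_shapes(graph)  # post-expansion: no rename dicts needed

proposed_pipeline(graph=None)
```

Keeping `determine_thread_shapes` before `expand_graph`, as this PR does, would instead require it to remap symbols via the rename dictionaries.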
The old call site after partitioning is removed accordingly:

```diff
@@ -249,9 +252,6 @@ def _trace_and_get_kernel_signature(
         # Partition strided operators.
         partition_strided_operators(graph, self.constraints)
 
-        # Analyze Thread Shapes per Op.
-        determine_thread_shapes(graph)
-
         # Align sizes to WG/Tile sizes
         # This pass changes indexing keys, which can interfere with other passes,
         # so it should be called close to the end of pipeline.
```
Additional review comments:

**Reviewer:** nit: I understand what this function is doing, but I would reword this comment a little to make it clearer.

**Author:** sg! :) I updated the description, please check again :)

**Reviewer:** Minor typo: "outer"; otherwise looks good, thanks!