[TKW] Propagate GetResult/IterArg of Reductions out for expansion and thread shape (#225)

In this PR, we add support for propagating indexing, expansion, and thread
shape information from the IterArg/induction variable of a Reduction loop to
its users outside the loop.

This matters when, for example, a reduction loop carries two induction
variables with different layouts/indices and we want to apply a binaryOp to
them before writing them out. In that case we need to figure out whether a
broadcast is required, which is impossible unless this information is
propagated from the IterArg to the GetResult outside the loop and along its
use-def chain.
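
As a condensed sketch, this is the pattern the change unlocks (adapted from
the lit test updated in this PR, lit_tests/kernel/wave/codegen.py):

@tkw.reduction(K, init_args=[init_max, c_reg])
def repeat(partial_max, acc):
    ...
    return partial_max, acc  # two iter args with different layouts

res_max, res_mm = repeat  # f16 row-max layout vs. f32 MMA accumulator layout
# The binaryOp below mixes the two layouts, so a broadcast of res_max must be
# inferred, which is only possible with the index info propagated out of the loop.
res = res_mm / tkw.cast(res_max, tkl.f32)
tkw.write(res, c, elements_per_thread=STORE_ELEMS_PER_THREAD)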

---------

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
raikonenfnu authored Oct 21, 2024
1 parent 97e0517 commit 10cf8c7
Showing 7 changed files with 106 additions and 50 deletions.
34 changes: 26 additions & 8 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -961,6 +961,20 @@ def wrapper(f):

    return wrapper

    def get_root_graph(self):
        """
        Return the "root"/outermost layer of our computation graph.
        This is done by iteratively accessing the parent_graph of the
        current graph until we reach the root graph, which is the one
        that has the "subgraphs" attribute.
        """
        cur_graph = self.graph
        while not hasattr(cur_graph, "subgraphs"):
            if not hasattr(cur_graph, "parent_op"):
                raise ValueError("All subgraphs should have parent_op")
            cur_graph = cur_graph.parent_op.graph
        return cur_graph
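
A minimal usage sketch (node and reduction here are hypothetical; this mirrors
how the Reduction.index property below resolves its subgraph):

custom = get_custom(node)  # node: fx.Node living inside a nested subgraph
root = custom.get_root_graph()  # walks parent_op.graph links upward
body = root.subgraphs[reduction.subgraph_name]  # only the root owns "subgraphs"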

    @property
    def indexing_dims(self) -> list[IndexSymbol] | list[list[IndexSymbol]]:
        expand_dims: list[IndexSymbol] = []
@@ -1014,15 +1028,18 @@ def outputs(self, graph: fx.Graph) -> list[fx.Node]:

    @property
    def index(self) -> list[dict[IndexSymbol, IndexSequence]]:
        if not hasattr(self.graph, "subgraphs"):
            return None
        for node in self.graph.subgraphs[self.subgraph_name].nodes:
        for node in self.get_root_graph().subgraphs[self.subgraph_name].nodes:
            if isinstance(output := get_custom(node), Output):
                return_vals = output.return_vals[0]
                return (
                    [val.index for val in return_vals]
                    if isinstance(return_vals, list)
                    else None
                    [
                        get_custom(val).acc_index
                        if isinstance(get_custom(val), MMA)
                        else val.index
                        for val in return_vals
                    ]
                    if isinstance(return_vals, (Sequence))
                    else return_vals.index
                )

    @index.setter
@@ -1111,11 +1128,12 @@ def indexing_dims(self) -> list[IndexExpr]:
    @property
    def index(self) -> dict[IndexSymbol, IndexSequence]:
        custom = get_custom(self.value)
        if custom.index is None:
        custom_index = custom.index
        if custom_index is None:
            return None
        if not isinstance(custom, Reduction):
            return custom.index
        assert isinstance(custom.index, list) and self.res_idx < len(
        assert isinstance(custom_index, Sequence) and self.res_idx < len(
            custom.indexing_dims
        )
        return custom.index[self.res_idx]
37 changes: 32 additions & 5 deletions iree/turbine/kernel/wave/thread_shape_analysis.py
@@ -53,6 +53,7 @@ def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]):

anchorOpTypes = (Read, Write, MMA, ReduceOp)
noHandleTypes = (Placeholder, Output, ExtractSlice, Allocate)
legalSubtypes = (IterArg,)
nonPropagatableTypes = anchorOpTypes + noHandleTypes


@@ -61,17 +62,27 @@ def is_anchor_op(node: fx.Node):


def propagatable_op(node: fx.Node):
    return not isinstance(get_custom(node), nonPropagatableTypes)
    custom_node = get_custom(node)
    # IterArg subclasses Placeholder (normally non-propagatable), so it is
    # whitelisted via legalSubtypes to let thread-shape information flow
    # through reduction-carried values.
    return not isinstance(custom_node, nonPropagatableTypes) or isinstance(
        custom_node, legalSubtypes
    )


def handle_binaryop_conflict(custom_node: CustomOp) -> list[fx.Node]:
    """
    This function attempts to resolve binaryOp conflicts by inserting
    a broadcastOp. It then propagates the resolutions and returns the
    list of fx.Nodes that have been resolved.
    """

def handle_binaryop_conflict(custom_node: CustomOp):
    # Analyze if we can resolve the conflict with a broadcast.
    lhs = get_custom(custom_node.lhs)
    rhs = get_custom(custom_node.rhs)
    lhs_dim_set = set(lhs.type.symbolic_shape)
    rhs_dim_set = set(rhs.type.symbolic_shape)
    if lhs_dim_set == rhs_dim_set:
        raise ValueError("Cannot broadcast if lhs and rhs is already same.")
        # Could be caused by consumers (likely also binaryOps) of this node.
        return []
    if lhs_dim_set.isdisjoint(rhs_dim_set):
        raise ValueError("Cannot broadcast if lhs and rhs has disjointed shapes.")
    # Determine the correct indexSize for binaryOp and insert broadcasting.
@@ -84,7 +95,8 @@ def handle_binaryop_conflict(custom_node: CustomOp):
    propagated_resolutions = capture_forward_slice(broadcast.fx_node, propagatable_op)
    for node in propagated_resolutions:
        get_custom(node).index = dst_op.index
    return propagated_resolutions
    resolved_resolutions = capture_backward_slice(broadcast.fx_node, propagatable_op)
    return propagated_resolutions.union(resolved_resolutions)
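
A standalone sketch of the resolution rule above, in pure Python with
hypothetical symbolic shapes: identical shapes mean the conflict lies
elsewhere, disjoint shapes cannot be reconciled, and a subset shape is
broadcast up to the larger one.

def broadcast_decision(lhs_dims: set, rhs_dims: set):
    if lhs_dims == rhs_dims:
        return None  # nothing to broadcast; conflict must come from consumers
    if lhs_dims.isdisjoint(rhs_dims):
        raise ValueError("cannot broadcast disjoint shapes")
    return max(lhs_dims, rhs_dims, key=len)  # broadcast to the larger shape

assert broadcast_decision({"M"}, {"M", "N"}) == {"M", "N"}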


# Returns True iff all conflicts are handled successfully.
@@ -155,11 +167,26 @@ def determine_thread_shapes(trace: CapturedTrace):
    for anchor_op in anchor_ops:
        custom = get_custom(anchor_op)
        index_sizes = get_custom_dim_sizes(custom)
        if isinstance(custom, (Read, ReduceOp)):
        if isinstance(custom, Read):
            fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op)
            thread_size_to_ops[index_sizes] = thread_size_to_ops.get(
                index_sizes, set([])
            ).union(fwd_slice)
        elif isinstance(custom, ReduceOp):
            fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op)
            bwd_slice = set()
            if custom.init != None and not isinstance(
                get_custom(custom.init), ReduceOp
            ):
                bwd_slice = capture_backward_slice(custom.init, propagatable_op)
            reduce_dims = frozenset(
                [DimSize(dim, 1) for dim in custom.index.keys() if dim != custom.dim]
            )
            thread_size_to_ops[reduce_dims] = (
                thread_size_to_ops.get(reduce_dims, set([]))
                .union(fwd_slice)
                .union(bwd_slice)
            )
        elif isinstance(custom, Write):
            bwd_slice = capture_backward_slice(custom.fx_node, propagatable_op)
            thread_size_to_ops[index_sizes] = thread_size_to_ops.get(
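
For illustration, a pure-Python sketch of the frozenset key built in the
ReduceOp branch above (DimSize is approximated by a namedtuple; dims M and K
are hypothetical): every dim other than the reduction dim collapses to thread
size 1.

from collections import namedtuple

DimSize = namedtuple("DimSize", ["dim", "size"])  # stand-in for the real DimSize

index_dims, reduction_dim = ["M", "K"], "K"
reduce_dims = frozenset(DimSize(d, 1) for d in index_dims if d != reduction_dim)
assert reduce_dims == frozenset({DimSize("M", 1)})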
23 changes: 14 additions & 9 deletions iree/turbine/kernel/wave/utils.py
@@ -509,18 +509,23 @@ def get_users(
            init_arg_idx = custom.init_args.index(node)
            users.append(custom.iter_args[init_arg_idx])
            continue
        if isinstance(custom, Output) and reduction:
        if isinstance(custom, Output):
            # Map the output back to its corresponding GetResult.
            return_vals = custom.return_vals[0]
            get_results = sorted(
                [x for x in reduction.users if isinstance(get_custom(x), GetResult)],
                lambda x: get_custom(x).res_idx,
            )
            if isinstance(return_vals, list):
                output_idx = return_vals.index(node)
                users.append(get_results[output_idx])
            parent_reduction = custom.graph.parent_op
            if not isinstance(return_vals, (list, tuple)):
                users.append(next(iter(parent_reduction.users)))
            else:
                users.append(get_results[0])
                # Handles the case where DCE eliminated an unused GetResult.
                get_results = {
                    get_custom(x).res_idx: x
                    for x in parent_reduction.users
                    if isinstance(get_custom(x), GetResult)
                }
                output_idx = return_vals.index(node)
                # Sometimes an IterArg is only used within the tkw.Reduction region.
                if output_idx in get_results:
                    users.append(get_results[output_idx])
            continue
        users.append(user)
    return users, reduction
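
A pure-Python sketch of why the lookup above is keyed by res_idx (all values
hypothetical): if DCE removed the GetResult for one result index, positional
indexing into a sorted list would mis-map the remaining ones, while the dict
lookup simply skips it.

return_vals = ["max_val", "mma_val", "unused_val"]  # Output's return values
get_results = {0: "get_result_0", 1: "get_result_1"}  # surviving res_idx -> node
output_idx = return_vals.index("unused_val")  # res_idx 2
assert output_idx not in get_results  # no users outside the loop, so skip it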
6 changes: 3 additions & 3 deletions iree/turbine/kernel/wave/wave.py
@@ -231,6 +231,9 @@ def _trace_and_get_kernel_signature(
        # Set indices.
        set_node_indices(graph, self.constraints)

        # Analyze Thread Shapes per Op.
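        # Running this before expansion makes the thread-shape and index
        # information propagated out of Reduction loops (IterArg -> GetResult)
        # available when the graph is expanded.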
        determine_thread_shapes(graph)

        # Expansion
        expand_graph(graph, self.constraints)

@@ -249,9 +252,6 @@
        # Partition strided operators.
        partition_strided_operators(graph, self.constraints)

        # Analyze Thread Shapes per Op.
        determine_thread_shapes(graph)

        # Align sizes to WG/Tile sizes
        # This pass changes indexing keys, which can interfere with other passes,
        # so it should be called close to the end of pipeline.
30 changes: 18 additions & 12 deletions lit_tests/kernel/wave/codegen.py
@@ -964,9 +964,11 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
# CHECK-COUNT-8: amdgpu.mfma


# This test is used to check two things
# This test is used to check three things
# 1. Reduction with multiple different types (MMA, ReduceOp) of iterArg works.
# 2. ReduceOp lowering works using constraints from MMA (not just vector_shape).
# 3. We can propagate the layout of multiple Reduction results through
#    IterArg/GetResult and observe that a broadcast is generated to resolve
#    the binaryOp.
@run_test
def test_gemm_and_reduce():
    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
@@ -987,8 +989,7 @@
    def gemm(
        a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
        b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
        c: tkl.Memory[M, ADDRESS_SPACE_0, tkl.f16],
        d: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
        c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
    ):
        c_reg = tkl.Register[M, N, tkl.f32](0.0)
        init_max = tkl.Register[M, tkl.f16](-1e6)
@@ -1004,8 +1005,8 @@ def repeat(
            return partial_max, acc

        res_max, res_mm = repeat
        tkw.write(res_max, c, elements_per_thread=1)
        tkw.write(res_mm, d, elements_per_thread=STORE_ELEMS_PER_THREAD)
        res = res_mm / tkw.cast(res_max, tkl.f32)
        tkw.write(res, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

    with tk.gen.TestLaunchContext(
        {
@@ -1024,22 +1025,25 @@
    ):
        a = torch.randn(64, 32, dtype=torch.float16)
        b = torch.randn(128, 32, dtype=torch.float16)
        c = torch.zeros(64, dtype=torch.float16)
        d = torch.zeros(64, 128, dtype=torch.float32)
        print(gemm(a, b, c, d).module_op)
        c = torch.zeros(64, 128, dtype=torch.float32)
        print(gemm(a, b, c).module_op)
# CHECK-DAG: %[[C0_IDX:.+]] = arith.constant 0 : index
# CHECK-DAG: %[[C4_IDX:.+]] = arith.constant 4 : index
# CHECK-DAG: %[[C1_IDX:.+]] = arith.constant 1 : index

# Tile Reduction Loop
# Note: Shape is 32x20 instead of 32x16 because of padding to avoid bank conflicts
# CHECK: %{{.*}}:2 = scf.for %[[ITER:.+]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C1_IDX]]
# CHECK: %[[LOOP:.+]]:2 = scf.for %[[ITER:.+]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C1_IDX]]
# CHECK-SAME: iter_args(%[[ACC0:.+]] = %{{.*}}, %[[ACC1:.+]] = {{.*}})
# CHECK-COUNT-2: vector.load{{.*}} memref<32x20xf16, #gpu.address_space<workgroup>>, vector<4xf16>
# CHECK-COUNT-6: gpu.shuffle xor
# CHECK: %[[MAX:.+]] = arith.maximumf %[[ACC0]], %{{.*}}
# CHECK: %[[MMA:.+]] = amdgpu.mfma %{{.*}} * %{{.*}} + %[[ACC1]]
# CHECK: scf.yield %[[MAX]], %[[MMA]] : vector<1xf16>, vector<4xf32>
# CHECK: %[[MAX_EXT:.+]] = arith.extf %[[LOOP]]#0 : vector<1xf16> to vector<1xf32>
# CHECK: %[[BCAST_SRC:.+]] = vector.extract %[[MAX_EXT]][0] : f32 from vector<1xf32>
# CHECK: %[[BROADCAST:.+]] = vector.splat %[[BCAST_SRC]] : vector<4xf32>
# CHECK: arith.divf %[[LOOP]]#1, %[[BROADCAST]] : vector<4xf32>


@run_test
@@ -1757,14 +1761,16 @@ def test(
# CHECK: %[[RHS_0:.+]] = vector.load %[[RHS]][%[[X_SLICE_0]]] : memref<256xf16, strided<[1], offset: ?>>, vector<1xf16>
# CHECK: %[[RHS_1:.+]] = vector.load %[[RHS]][%[[X_SLICE_1]]] : memref<256xf16, strided<[1], offset: ?>>, vector<1xf16>

# 1st Broadcast-ADD RHS
# 1st Broadcast RHS
# CHECK: %[[EXTRACT_0:.+]] = vector.extract %[[RHS_0]][0] : f16 from vector<1xf16>
# CHECK: %[[BCAST_RHS_0:.+]] = vector.splat %[[EXTRACT_0]] : vector<2xf16>
# CHECK: arith.addf %[[LHS_0]], %[[BCAST_RHS_0]] : vector<2xf16>

# 2nd Broadcast-ADD RHS
# 2nd Broadcast RHS
# CHECK: %[[EXTRACT_1:.+]] = vector.extract %[[RHS_1]][0] : f16 from vector<1xf16>
# CHECK: %[[BCAST_RHS_1:.+]] = vector.splat %[[EXTRACT_1]] : vector<2xf16>

# Broadcast-ADD RHS
# CHECK: arith.addf %[[LHS_0]], %[[BCAST_RHS_0]] : vector<2xf16>
# CHECK: arith.addf %[[LHS_1]], %[[BCAST_RHS_1]] : vector<2xf16>


18 changes: 9 additions & 9 deletions lit_tests/kernel/wave/expansion.py
@@ -266,13 +266,13 @@ def test_gemm():
# CHECK-NEXT: get_result(value=reduction, res_idx=1)
# CHECK-NEXT: get_result(value=reduction, res_idx=0)
# CHECK-NEXT: write(register_=getresult_0_0_0
# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}
# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}
# CHECK-NEXT: write(register_=getresult_1_1_0
# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}
# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}
# CHECK-NEXT: write(register_=getresult_1_0_0
# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}
# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}
# CHECK-NEXT: write(register_=getresult_0_1_0
# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}
# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}
# CHECK-NEXT: output

# Reduction subgraph:
@@ -452,13 +452,13 @@ def test_batched_gemm():
# CHECK-NEXT: get_result(value=reduction, res_idx=1)
# CHECK-NEXT: get_result(value=reduction, res_idx=0)
# CHECK-NEXT: write(register_=getresult_0_0_0
# CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}
# CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}
# CHECK-NEXT: write(register_=getresult_1_1_0
# CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}
# CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}
# CHECK-NEXT: write(register_=getresult_1_0_0
# CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}
# CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}
# CHECK-NEXT: write(register_=getresult_0_1_0
# CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}
# CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}
# CHECK-NEXT: output

# Reduction subgraph:
@@ -594,7 +594,7 @@ def test_gemm_reduction_expansion_only():
# CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0]
# CHECK-NEXT: get_result(value=reduction, res_idx=0)
# CHECK-NEXT: write(register_=getresult_0_0_0
# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)})
# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)})
# CHECK-NEXT: output(return_vals=(None,))

# Reduction subgraph:
8 changes: 4 additions & 4 deletions lit_tests/kernel/wave/minimize_global_loads.py
@@ -152,13 +152,13 @@ def test_gemm():
# CHECK-NEXT: get_result(value=reduction, res_idx=1)
# CHECK-NEXT: get_result(value=reduction, res_idx=0)
# CHECK-NEXT: write(register_=getresult_0_0_0, memory=c
# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)})
# CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)})
# CHECK-NEXT: write(register_=getresult_1_1_0, memory=c
# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16})
# CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16})
# CHECK-NEXT: write(register_=getresult_1_0_0, memory=c
# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)})
# CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)})
# CHECK-NEXT: write(register_=getresult_0_1_0, memory=c
# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16})
# CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16})

# Reduction subgraph:
# CHECK: %acc_0_0_0
