[TKW] Propagate GetResult/IterArg of Reductions out for expansion and thread shape #225

Merged · 3 commits · Oct 21, 2024 · showing changes from 2 commits
34 changes: 26 additions & 8 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -961,6 +961,20 @@ def wrapper(f):

return wrapper

def get_root_graph(self):
"""
Return the "root"/most outter layer of our computation graph.
This is done by iteratively accessing parent_graph of current
graph. This is done until we find the "root" graph who
will have "subgraph" attribute.
"""
cur_graph = self.graph
while not hasattr(cur_graph, "subgraphs"):
if not hasattr(cur_graph, "parent_op"):
raise ValueError("All subgraphs should have parent_op")
cur_graph = cur_graph.parent_op.graph
return cur_graph

@property
def indexing_dims(self) -> list[IndexSymbol] | list[list[IndexSymbol]]:
expand_dims: list[IndexSymbol] = []
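As a standalone illustration of `get_root_graph` above, here is a minimal sketch of the traversal using `SimpleNamespace` stand-ins rather than real fx graphs (only the `parent_op`/`subgraphs` attributes mirror the code; everything else is hypothetical):

```python
from types import SimpleNamespace

# The root graph owns "subgraphs"; nested graphs reach it via parent_op.graph.
root = SimpleNamespace(subgraphs={"region_0": None})
mid = SimpleNamespace(parent_op=SimpleNamespace(graph=root))
leaf = SimpleNamespace(parent_op=SimpleNamespace(graph=mid))

def get_root_graph(graph):
    cur_graph = graph
    while not hasattr(cur_graph, "subgraphs"):
        if not hasattr(cur_graph, "parent_op"):
            raise ValueError("All subgraphs should have parent_op")
        cur_graph = cur_graph.parent_op.graph
    return cur_graph

assert get_root_graph(leaf) is root
```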
@@ -1014,15 +1028,18 @@ def outputs(self, graph: fx.Graph) -> list[fx.Node]:

@property
def index(self) -> list[dict[IndexSymbol, IndexSequence]]:
- if not hasattr(self.graph, "subgraphs"):
- return None
- for node in self.graph.subgraphs[self.subgraph_name].nodes:
+ for node in self.get_root_graph().subgraphs[self.subgraph_name].nodes:
if isinstance(output := get_custom(node), Output):
return_vals = output.return_vals[0]
return (
- [val.index for val in return_vals]
- if isinstance(return_vals, list)
- else None
+ [
+ get_custom(val).acc_index
+ if isinstance(get_custom(val), MMA)
+ else val.index
+ for val in return_vals
+ ]
+ if isinstance(return_vals, Sequence)
+ else return_vals.index
)

@index.setter
@@ -1111,11 +1128,12 @@ def indexing_dims(self) -> list[IndexExpr]:
@property
def index(self) -> dict[IndexSymbol, IndexSequence]:
custom = get_custom(self.value)
- if custom.index is None:
+ custom_index = custom.index
+ if custom_index is None:
return None
if not isinstance(custom, Reduction):
return custom.index
- assert isinstance(custom.index, list) and self.res_idx < len(
+ assert isinstance(custom_index, Sequence) and self.res_idx < len(
custom.indexing_dims
)
return custom.index[self.res_idx]
37 changes: 32 additions & 5 deletions iree/turbine/kernel/wave/thread_shape_analysis.py
@@ -53,6 +53,7 @@ def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]):

anchorOpTypes = (Read, Write, MMA, ReduceOp)
noHandleTypes = (Placeholder, Output, ExtractSlice, Allocate)
legalSubtypes = (IterArg,)
nonPropagatableTypes = anchorOpTypes + noHandleTypes


@@ -61,17 +62,27 @@ def is_anchor_op(node: fx.Node):


def propagatable_op(node: fx.Node):
- return not isinstance(get_custom(node), nonPropagatableTypes)
+ custom_node = get_custom(node)
+ return not isinstance(custom_node, nonPropagatableTypes) or isinstance(
+ custom_node, legalSubtypes
+ )


- def handle_binaryop_conflict(custom_node: CustomOp):
+ def handle_binaryop_conflict(custom_node: CustomOp) -> list[fx.Node]:
+ """
+ This function attempts to resolve binaryOp conflicts
+ by inserting a broadcastOp. It then propagates the resolutions
+ and returns the list of fx.Nodes that were resolved.
+ """
# Analyze if we can resolve conflict with broadcast.
lhs = get_custom(custom_node.lhs)
rhs = get_custom(custom_node.rhs)
lhs_dim_set = set(lhs.type.symbolic_shape)
rhs_dim_set = set(rhs.type.symbolic_shape)
if lhs_dim_set == rhs_dim_set:
raise ValueError("Cannot broadcast if lhs and rhs is already same.")
# Could be caused by consumers(likely also binaryOp) of this node.
harsh-nod marked this conversation as resolved.
Show resolved Hide resolved
return []
if lhs_dim_set.isdisjoint(rhs_dim_set):
raise ValueError("Cannot broadcast if lhs and rhs has disjointed shapes.")
# Determine the correct indexSize for binaryOp and insert broadcasting.
@@ -84,7 +95,8 @@ def handle_binaryop_conflict(custom_node: CustomOp):
propagated_resolutions = capture_forward_slice(broadcast.fx_node, propagatable_op)
for node in propagated_resolutions:
get_custom(node).index = dst_op.index
- return propagated_resolutions
+ resolved_resolutions = capture_backward_slice(broadcast.fx_node, propagatable_op)
+ return propagated_resolutions.union(resolved_resolutions)
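To make the broadcast-resolution rule above concrete, a small self-contained sketch of the dim-set logic (plain Python sets, not the tkw API):

```python
def classify_binaryop(lhs_dims: set, rhs_dims: set) -> str:
    # Equal shapes: nothing to broadcast here; the conflict lies elsewhere
    # (likely a consumer binaryOp), so the pass returns [] above.
    if lhs_dims == rhs_dims:
        return "no-op"
    # Disjoint shapes cannot be reconciled by broadcasting.
    if lhs_dims.isdisjoint(rhs_dims):
        raise ValueError("Cannot broadcast if lhs and rhs have disjoint shapes.")
    # Otherwise broadcast the operand with fewer dims up to the larger one.
    return "broadcast lhs" if len(lhs_dims) < len(rhs_dims) else "broadcast rhs"

print(classify_binaryop({"M"}, {"M", "N"}))       # broadcast lhs
print(classify_binaryop({"M", "N"}, {"M", "N"}))  # no-op
```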


# Returns True iff all conflicts are handled successfully.
@@ -155,11 +167,26 @@ def determine_thread_shapes(trace: CapturedTrace):
for anchor_op in anchor_ops:
custom = get_custom(anchor_op)
index_sizes = get_custom_dim_sizes(custom)
- if isinstance(custom, (Read, ReduceOp)):
+ if isinstance(custom, Read):
fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op)
thread_size_to_ops[index_sizes] = thread_size_to_ops.get(
index_sizes, set([])
).union(fwd_slice)
elif isinstance(custom, ReduceOp):
fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op)
bwd_slice = set()
if custom.init is not None and not isinstance(
    get_custom(custom.init), ReduceOp
):
    bwd_slice = capture_backward_slice(custom.init, propagatable_op)
reduce_dims = frozenset(
    [DimSize(dim, 1) for dim in custom.index.keys() if dim != custom.dim]
)
thread_size_to_ops[reduce_dims] = (
    thread_size_to_ops.get(reduce_dims, set([]))
    .union(fwd_slice)
    .union(bwd_slice)
)

(Review thread on the `custom.init` check above, resolved:)

> Contributor: Why do we need to check whether the init is a ReduceOp?
> Author: Similar reason as in chained matmul: we do not expand on the acc if the acc is also a matmul.
> Author: We can imagine a reduction feeding into another reduction, where we use the 1st reduction as the init to the second. It may then try to slice/attach information to all the operands of the 1st reduction, which is no good. :)
> Contributor: Right, so if you didn't have this check, then you would do a bwd_slice on the first reduction, which would overwrite the information from the first reduction.
> Author: Yeap!
elif isinstance(custom, Write):
bwd_slice = capture_backward_slice(custom.fx_node, propagatable_op)
thread_size_to_ops[index_sizes] = thread_size_to_ops.get(
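To illustrate the `custom.init` guard discussed in the thread above, a self-contained sketch with plain-Python stand-ins (`Node`, `backward_slice`, and the op strings are all hypothetical, not the tkw graph API):

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    name: str
    op: str
    inputs: list = field(default_factory=list)

def backward_slice(node: Node) -> list[Node]:
    # Transitively collect producers, mimicking capture_backward_slice.
    seen: list[Node] = []
    stack = [node]
    while stack:
        for producer in stack.pop().inputs:
            if producer not in seen:
                seen.append(producer)
                stack.append(producer)
    return seen

# reduce1's result feeds reduce2 as its init.
a = Node("a", "read")
reduce1 = Node("reduce1", "reduce", [a])
reduce2 = Node("reduce2", "reduce", [reduce1])

# Slicing reduce2's init would reach reduce1 and everything behind it,
# clobbering the thread shapes reduce1 already resolved for its operands.
print([n.name for n in backward_slice(reduce2)])  # ['reduce1', 'a']

# Hence the guard: skip the backward slice when init is itself a reduce.
init = reduce2.inputs[0]
bwd = backward_slice(init) if init.op != "reduce" else []
print([n.name for n in bwd])  # []
```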
23 changes: 14 additions & 9 deletions iree/turbine/kernel/wave/utils.py
@@ -509,18 +509,23 @@ def get_users(
init_arg_idx = custom.init_args.index(node)
users.append(custom.iter_args[init_arg_idx])
continue
- if isinstance(custom, Output) and reduction:
+ if isinstance(custom, Output):
# Map output to get result
return_vals = custom.return_vals[0]
- get_results = sorted(
- [x for x in reduction.users if isinstance(get_custom(x), GetResult)],
- lambda x: get_custom(x).res_idx,
- )
- if isinstance(return_vals, list):
- output_idx = return_vals.index(node)
- users.append(get_results[output_idx])
+ parent_reduction = custom.graph.parent_op
+ if not isinstance(return_vals, (list, tuple)):
+ users.append(next(iter(parent_reduction.users)))
else:
- users.append(get_results[0])
+ # Handles the case where DCE eliminated an unused GetResult.
+ get_results = {
+ get_custom(x).res_idx: x
+ for x in parent_reduction.users
+ if isinstance(get_custom(x), GetResult)
+ }
+ output_idx = return_vals.index(node)
+ # Sometimes an IterArg is only used within the tkw.Reduction region.
+ if output_idx in get_results:
+ users.append(get_results[output_idx])
continue
users.append(user)
return users, reduction
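A small sketch of why the new `get_results` mapping is keyed by `res_idx` rather than indexed positionally (the values here are hypothetical):

```python
# res_idx values of the GetResult users that survived DCE; res_idx 1's
# result was unused, so DCE removed its GetResult node.
surviving = [0, 2]
output_idx = 2

# Positional indexing into a sorted list pairs output 2 with the wrong
# user (or raises IndexError) once any GetResult has been eliminated.
as_list = sorted(surviving)
print(as_list[1])  # 2 -- position no longer matches res_idx

# Keying by res_idx stays correct, and outputs whose GetResult was
# DCE'd (or that are only used inside the Reduction region) are skipped.
as_dict = {idx: f"get_result_{idx}" for idx in surviving}
if output_idx in as_dict:
    print(as_dict[output_idx])  # get_result_2
```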
6 changes: 3 additions & 3 deletions iree/turbine/kernel/wave/wave.py
@@ -231,6 +231,9 @@ def _trace_and_get_kernel_signature(
# Set indices.
set_node_indices(graph, self.constraints)

+ # Analyze Thread Shapes per Op.
+ determine_thread_shapes(graph)
(Review thread on moving `determine_thread_shapes` here, resolved:)

> Contributor: nice! :)
> Author: thanks! :)
> Contributor: Actually, thinking about this more, this could cause some issues with the symbol renaming I am doing. Just wondering, does thread shape analysis really get easier by moving it here, or can we keep it post-expansion after I remove all the renaming?
> Contributor: What I am thinking currently is: rename, set index, expand, set post-expand index, remove renames, and then the rest of the pipeline.
> Contributor: We can leave thread shape analysis here as well, but then it needs to handle renames and essentially remap variables based on the dicts in renames.
> Author: Yeah, I think this warrants lower-latency communication on VC haha.
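A sketch of the ordering proposed in the thread above, with stub passes (`rename_symbols`, `set_post_expansion_indices`, and `remove_renames` are hypothetical stand-ins; `set_node_indices`, `determine_thread_shapes`, and `expand_graph` appear in this PR):

```python
# Stub passes: bodies elided, only the ordering is the point.
def rename_symbols(graph): ...                            # hypothetical
def set_node_indices(graph, constraints): ...
def determine_thread_shapes(graph): ...
def expand_graph(graph, constraints): ...
def set_post_expansion_indices(graph, constraints): ...   # hypothetical
def remove_renames(graph): ...                            # hypothetical

def proposed_pipeline(graph, constraints):
    rename_symbols(graph)
    set_node_indices(graph, constraints)
    determine_thread_shapes(graph)   # this PR: runs before expansion
    expand_graph(graph, constraints)
    set_post_expansion_indices(graph, constraints)
    remove_renames(graph)
    # ... rest of the pipeline (minimize loads, partition strided ops, etc.)
```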


# Expansion
expand_graph(graph, self.constraints)

@@ -249,9 +252,6 @@ def _trace_and_get_kernel_signature(
# Partition strided operators.
partition_strided_operators(graph, self.constraints)

- # Analyze Thread Shapes per Op.
- determine_thread_shapes(graph)

# Align sizes to WG/Tile sizes
# This pass changes indexing keys, which can interfere with other passes,
# so it should be called close to the end of pipeline.
30 changes: 18 additions & 12 deletions lit_tests/kernel/wave/codegen.py
@@ -964,9 +964,11 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
# CHECK-COUNT-8: amdgpu.mfma


- # This test is used to check two things
+ # This test is used to check three things
# 1. Reduction with multiple different types (MMA, ReduceOp) of iterArg works
# 2. ReduceOp lowering works using constraints from MMA (not just vector_shape).
+ # 3. We can propagate the layout of multiple Reduction results through IterArg/GetResult
+ #    and observe that a broadcast is generated to resolve the binaryOp.
@run_test
def test_gemm_and_reduce():
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
@@ -987,8 +989,7 @@ def gemm(
def gemm(
a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
- c: tkl.Memory[M, ADDRESS_SPACE_0, tkl.f16],
- d: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
+ c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
):
c_reg = tkl.Register[M, N, tkl.f32](0.0)
init_max = tkl.Register[M, tkl.f16](-1e6)
@@ -1004,8 +1005,8 @@ def repeat(
return partial_max, acc

res_max, res_mm = repeat
- tkw.write(res_max, c, elements_per_thread=1)
- tkw.write(res_mm, d, elements_per_thread=STORE_ELEMS_PER_THREAD)
+ res = res_mm / tkw.cast(res_max, tkl.f32)
+ tkw.write(res, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

with tk.gen.TestLaunchContext(
{
@@ -1024,22 +1025,25 @@
):
a = torch.randn(64, 32, dtype=torch.float16)
b = torch.randn(128, 32, dtype=torch.float16)
- c = torch.zeros(64, dtype=torch.float16)
- d = torch.zeros(64, 128, dtype=torch.float32)
- print(gemm(a, b, c, d).module_op)
+ c = torch.zeros(64, 128, dtype=torch.float32)
+ print(gemm(a, b, c).module_op)
# CHECK-DAG: %[[C0_IDX:.+]] = arith.constant 0 : index
# CHECK-DAG: %[[C4_IDX:.+]] = arith.constant 4 : index
# CHECK-DAG: %[[C1_IDX:.+]] = arith.constant 1 : index

# Tile Reduction Loop
# Note: Shape is 32x20 instead of 32x16 because of padding to avoid bank conflicts
- # CHECK: %{{.*}}:2 = scf.for %[[ITER:.+]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C1_IDX]]
+ # CHECK: %[[LOOP:.+]]:2 = scf.for %[[ITER:.+]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C1_IDX]]
# CHECK-SAME: iter_args(%[[ACC0:.+]] = %{{.*}}, %[[ACC1:.+]] = {{.*}})
# CHECK-COUNT-2: vector.load{{.*}} memref<32x20xf16, #gpu.address_space<workgroup>>, vector<4xf16>
# CHECK-COUNT-6: gpu.shuffle xor
# CHECK: %[[MAX:.+]] = arith.maximumf %[[ACC0]], %{{.*}}
# CHECK: %[[MMA:.+]] = amdgpu.mfma %{{.*}} * %{{.*}} + %[[ACC1]]
# CHECK: scf.yield %[[MAX]], %[[MMA]] : vector<1xf16>, vector<4xf32>
# CHECK: %[[MAX_EXT:.+]] = arith.extf %[[LOOP]]#0 : vector<1xf16> to vector<1xf32>
# CHECK: %[[BCAST_SRC:.+]] = vector.extract %[[MAX_EXT]][0] : f32 from vector<1xf32>
# CHECK: %[[BROADCAST:.+]] = vector.splat %[[BCAST_SRC]] : vector<4xf32>
# CHECK: arith.divf %[[LOOP]]#1, %[[BROADCAST]] : vector<4xf32>


@run_test
@@ -1757,14 +1761,16 @@ def test(
# CHECK: %[[RHS_0:.+]] = vector.load %[[RHS]][%[[X_SLICE_0]]] : memref<256xf16, strided<[1], offset: ?>>, vector<1xf16>
# CHECK: %[[RHS_1:.+]] = vector.load %[[RHS]][%[[X_SLICE_1]]] : memref<256xf16, strided<[1], offset: ?>>, vector<1xf16>

- # 1st Broadcast-ADD RHS
+ # 1st Broadcast RHS
# CHECK: %[[EXTRACT_0:.+]] = vector.extract %[[RHS_0]][0] : f16 from vector<1xf16>
# CHECK: %[[BCAST_RHS_0:.+]] = vector.splat %[[EXTRACT_0]] : vector<2xf16>
- # CHECK: arith.addf %[[LHS_0]], %[[BCAST_RHS_0]] : vector<2xf16>

- # 2nd Broadcast-ADD RHS
+ # 2nd Broadcast RHS
# CHECK: %[[EXTRACT_1:.+]] = vector.extract %[[RHS_1]][0] : f16 from vector<1xf16>
# CHECK: %[[BCAST_RHS_1:.+]] = vector.splat %[[EXTRACT_1]] : vector<2xf16>

+ # Broadcast-ADD RHS
+ # CHECK: arith.addf %[[LHS_0]], %[[BCAST_RHS_0]] : vector<2xf16>
+ # CHECK: arith.addf %[[LHS_1]], %[[BCAST_RHS_1]] : vector<2xf16>


18 changes: 9 additions & 9 deletions lit_tests/kernel/wave/expansion.py
@@ -266,13 +266,13 @@ def test_gemm():
# CHECK-NEXT: get_result(value=reduction, res_idx=1)
# CHECK-NEXT: get_result(value=reduction, res_idx=0)
# CHECK-NEXT: write(register_=getresult_0_0_0
- # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}
+ # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}
# CHECK-NEXT: write(register_=getresult_1_1_0
- # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}
+ # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}
# CHECK-NEXT: write(register_=getresult_1_0_0
- # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}
+ # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}
# CHECK-NEXT: write(register_=getresult_0_1_0
- # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}
+ # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}
# CHECK-NEXT: output

# Reduction subgraph:
@@ -452,13 +452,13 @@ def test_batched_gemm():
# CHECK-NEXT: get_result(value=reduction, res_idx=1)
# CHECK-NEXT: get_result(value=reduction, res_idx=0)
# CHECK-NEXT: write(register_=getresult_0_0_0
- # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}
+ # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}
# CHECK-NEXT: write(register_=getresult_1_1_0
- # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}
+ # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}
# CHECK-NEXT: write(register_=getresult_1_0_0
- # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}
+ # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}
# CHECK-NEXT: write(register_=getresult_0_1_0
- # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}
+ # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}
# CHECK-NEXT: output

# Reduction subgraph:
@@ -594,7 +594,7 @@ def test_gemm_reduction_expansion_only():
# CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0]
# CHECK-NEXT: get_result(value=reduction, res_idx=0)
# CHECK-NEXT: write(register_=getresult_0_0_0
- # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)})
+ # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)})
# CHECK-NEXT: output(return_vals=(None,))

# Reduction subgraph:
8 changes: 4 additions & 4 deletions lit_tests/kernel/wave/minimize_global_loads.py
@@ -152,13 +152,13 @@ def test_gemm():
# CHECK-NEXT: get_result(value=reduction, res_idx=1)
# CHECK-NEXT: get_result(value=reduction, res_idx=0)
# CHECK-NEXT: write(register_=getresult_0_0_0, memory=c
- # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)})
+ # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)})
# CHECK-NEXT: write(register_=getresult_1_1_0, memory=c
- # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16})
+ # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16})
# CHECK-NEXT: write(register_=getresult_1_0_0, memory=c
- # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)})
+ # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)})
# CHECK-NEXT: write(register_=getresult_0_1_0, memory=c
- # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16})
+ # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16})

# Reduction subgraph:
# CHECK: %acc_0_0_0