diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index 17a6a3fa..19e3f64c 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -961,6 +961,20 @@ def wrapper(f): return wrapper + def get_root_graph(self): + """ + Return the "root"/outermost layer of our computation graph. + This is done by iteratively accessing parent_graph of current + graph. This is done until we find the "root" graph who + will have "subgraph" attribute. + """ + cur_graph = self.graph + while not hasattr(cur_graph, "subgraphs"): + if not hasattr(cur_graph, "parent_op"): + raise ValueError("All subgraphs should have parent_op") + cur_graph = cur_graph.parent_op.graph + return cur_graph + @property def indexing_dims(self) -> list[IndexSymbol] | list[list[IndexSymbol]]: expand_dims: list[IndexSymbol] = [] @@ -1014,15 +1028,18 @@ def outputs(self, graph: fx.Graph) -> list[fx.Node]: @property def index(self) -> list[dict[IndexSymbol, IndexSequence]]: - if not hasattr(self.graph, "subgraphs"): - return None - for node in self.graph.subgraphs[self.subgraph_name].nodes: + for node in self.get_root_graph().subgraphs[self.subgraph_name].nodes: if isinstance(output := get_custom(node), Output): return_vals = output.return_vals[0] return ( - [val.index for val in return_vals] - if isinstance(return_vals, list) - else None + [ + get_custom(val).acc_index + if isinstance(get_custom(val), MMA) + else val.index + for val in return_vals + ] + if isinstance(return_vals, (Sequence)) + else return_vals.index ) @index.setter @@ -1111,11 +1128,12 @@ def indexing_dims(self) -> list[IndexExpr]: @property def index(self) -> dict[IndexSymbol, IndexSequence]: custom = get_custom(self.value) - if custom.index is None: + custom_index = custom.index + if custom_index is None: return None if not isinstance(custom, Reduction): return custom.index - assert isinstance(custom.index, list) and self.res_idx < len( + assert isinstance(custom_index, Sequence) and self.res_idx < len( custom.indexing_dims ) return custom.index[self.res_idx] diff --git a/iree/turbine/kernel/wave/thread_shape_analysis.py b/iree/turbine/kernel/wave/thread_shape_analysis.py index e80dc2cb..7bcb6f57 100644 --- a/iree/turbine/kernel/wave/thread_shape_analysis.py +++ b/iree/turbine/kernel/wave/thread_shape_analysis.py @@ -53,6 +53,7 @@ def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]): anchorOpTypes = (Read, Write, MMA, ReduceOp) noHandleTypes = (Placeholder, Output, ExtractSlice, Allocate) +legalSubtypes = (IterArg,) nonPropagatableTypes = anchorOpTypes + noHandleTypes @@ -61,17 +62,27 @@ def is_anchor_op(node: fx.Node): def propagatable_op(node: fx.Node): - return not isinstance(get_custom(node), nonPropagatableTypes) + custom_node = get_custom(node) + return not isinstance(custom_node, nonPropagatableTypes) or isinstance( + custom_node, legalSubtypes + ) + +def handle_binaryop_conflict(custom_node: CustomOp) -> list[fx.Node]: + """ + This function will attempt to resolve binaryOp conflicts + by inserting broadcastOp. It will then propagate the resolutions, + and return the list of fx.Nodes that we have resolved. + """ -def handle_binaryop_conflict(custom_node: CustomOp): # Analyze if we can resolve conflict with broadcast. lhs = get_custom(custom_node.lhs) rhs = get_custom(custom_node.rhs) lhs_dim_set = set(lhs.type.symbolic_shape) rhs_dim_set = set(rhs.type.symbolic_shape) if lhs_dim_set == rhs_dim_set: - raise ValueError("Cannot broadcast if lhs and rhs is already same.") + # Could be caused by consumers(likely also binaryOp) of this node. + return [] if lhs_dim_set.isdisjoint(rhs_dim_set): raise ValueError("Cannot broadcast if lhs and rhs has disjointed shapes.") # Determine the correct indexSize for binaryOp and insert broadcasting. @@ -84,7 +95,8 @@ def handle_binaryop_conflict(custom_node: CustomOp): propagated_resolutions = capture_forward_slice(broadcast.fx_node, propagatable_op) for node in propagated_resolutions: get_custom(node).index = dst_op.index - return propagated_resolutions + resolved_resolutions = capture_backward_slice(broadcast.fx_node, propagatable_op) + return propagated_resolutions.union(resolved_resolutions) # Returns True iff all conflicts are handled succesfully. @@ -155,11 +167,26 @@ def determine_thread_shapes(trace: CapturedTrace): for anchor_op in anchor_ops: custom = get_custom(anchor_op) index_sizes = get_custom_dim_sizes(custom) - if isinstance(custom, (Read, ReduceOp)): + if isinstance(custom, Read): fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op) thread_size_to_ops[index_sizes] = thread_size_to_ops.get( index_sizes, set([]) ).union(fwd_slice) + elif isinstance(custom, ReduceOp): + fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op) + bwd_slice = set() + if custom.init != None and not isinstance( + get_custom(custom.init), ReduceOp + ): + bwd_slice = capture_backward_slice(custom.init, propagatable_op) + reduce_dims = frozenset( + [DimSize(dim, 1) for dim in custom.index.keys() if dim != custom.dim] + ) + thread_size_to_ops[reduce_dims] = ( + thread_size_to_ops.get(reduce_dims, set([])) + .union(fwd_slice) + .union(bwd_slice) + ) elif isinstance(custom, Write): bwd_slice = capture_backward_slice(custom.fx_node, propagatable_op) thread_size_to_ops[index_sizes] = thread_size_to_ops.get( diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index ead756ae..e052f90a 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -509,18 +509,23 @@ def get_users( init_arg_idx = custom.init_args.index(node) users.append(custom.iter_args[init_arg_idx]) continue - if isinstance(custom, Output) and reduction: + if isinstance(custom, Output): # Map output to get result return_vals = custom.return_vals[0] - get_results = sorted( - [x for x in reduction.users if isinstance(get_custom(x), GetResult)], - lambda x: get_custom(x).res_idx, - ) - if isinstance(return_vals, list): - output_idx = return_vals.index(node) - users.append(get_results[output_idx]) + parent_reduction = custom.graph.parent_op + if not isinstance(return_vals, (list, tuple)): + users.append(next(iter(parent_reduction.users))) else: - users.append(get_results[0]) + # Handles case where DCE eliminate unused GetResult. + get_results = { + get_custom(x).res_idx: x + for x in parent_reduction.users + if isinstance(get_custom(x), GetResult) + } + output_idx = return_vals.index(node) + # Sometime IterArg only used within the tkw.Reduction region + if output_idx in get_results: + users.append(get_results[output_idx]) continue users.append(user) return users, reduction diff --git a/iree/turbine/kernel/wave/wave.py b/iree/turbine/kernel/wave/wave.py index fb5c5912..07ca9ab1 100644 --- a/iree/turbine/kernel/wave/wave.py +++ b/iree/turbine/kernel/wave/wave.py @@ -231,6 +231,9 @@ def _trace_and_get_kernel_signature( # Set indices. set_node_indices(graph, self.constraints) + # Analyze Thread Shapes per Op. + determine_thread_shapes(graph) + # Expansion expand_graph(graph, self.constraints) @@ -249,9 +252,6 @@ def _trace_and_get_kernel_signature( # Partition strided operators. partition_strided_operators(graph, self.constraints) - # Analyze Thread Shapes per Op. - determine_thread_shapes(graph) - # Align sizes to WG/Tile sizes # This pass changes indexing keys, which can interfere with other passes, # so it should be called close to the end of pipeline. diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 430b3b25..d38a8259 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -964,9 +964,11 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK-COUNT-8: amdgpu.mfma -# This test is used to check two things +# This test is used to check three things # 1. Reduction with multiple different types(MMA, ReduceOp) of iterArg works # 2. ReduceOp lowering works using constraints from MMA (not just vector_shape). +# 3. We can propagate layout of multiple Reduction results through IterArg/GetResult +# and observe that broadcast is being generated to resolve binaryOp. @run_test def test_gemm_and_reduce(): constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] @@ -987,8 +989,7 @@ def test_gemm_and_reduce(): def gemm( a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], - c: tkl.Memory[M, ADDRESS_SPACE_0, tkl.f16], - d: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32], + c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32], ): c_reg = tkl.Register[M, N, tkl.f32](0.0) init_max = tkl.Register[M, tkl.f16](-1e6) @@ -1004,8 +1005,8 @@ def repeat( return partial_max, acc res_max, res_mm = repeat - tkw.write(res_max, c, elements_per_thread=1) - tkw.write(res_mm, d, elements_per_thread=STORE_ELEMS_PER_THREAD) + res = res_mm / tkw.cast(res_max, tkl.f32) + tkw.write(res, c, elements_per_thread=STORE_ELEMS_PER_THREAD) with tk.gen.TestLaunchContext( { @@ -1024,22 +1025,25 @@ def repeat( ): a = torch.randn(64, 32, dtype=torch.float16) b = torch.randn(128, 32, dtype=torch.float16) - c = torch.zeros(64, dtype=torch.float16) - d = torch.zeros(64, 128, dtype=torch.float32) - print(gemm(a, b, c, d).module_op) + c = torch.zeros(64, 128, dtype=torch.float32) + print(gemm(a, b, c).module_op) # CHECK-DAG: %[[C0_IDX:.+]] = arith.constant 0 : index # CHECK-DAG: %[[C4_IDX:.+]] = arith.constant 4 : index # CHECK-DAG: %[[C1_IDX:.+]] = arith.constant 1 : index # Tile Reduction Loop # Note: Shape is 32x20 instead of 32x16 because of padding to avoid bank conflicts - # CHECK: %{{.*}}:2 = scf.for %[[ITER:.+]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C1_IDX]] + # CHECK: %[[LOOP:.+]]:2 = scf.for %[[ITER:.+]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C1_IDX]] # CHECK-SAME: iter_args(%[[ACC0:.+]] = %{{.*}}, %[[ACC1:.+]] = {{.*}}) # CHECK-COUNT-2: vector.load{{.*}} memref<32x20xf16, #gpu.address_space>, vector<4xf16> # CHECK-COUNT-6: gpu.shuffle xor # CHECK: %[[MAX:.+]] = arith.maximumf %[[ACC0]], %{{.*}} # CHECK: %[[MMA:.+]] = amdgpu.mfma %{{.*}} * %{{.*}} + %[[ACC1]] # CHECK: scf.yield %[[MAX]], %[[MMA]] : vector<1xf16>, vector<4xf32> + # CHECK: %[[MAX_EXT:.+]] = arith.extf %[[LOOP]]#0 : vector<1xf16> to vector<1xf32> + # CHECK: %[[BCAST_SRC:.+]] = vector.extract %[[MAX_EXT]][0] : f32 from vector<1xf32> + # CHECK: %[[BROADCAST:.+]] = vector.splat %19 : vector<4xf32> + # CHECK: arith.divf %[[LOOP]]#1, %[[BROADCAST]] : vector<4xf32> @run_test @@ -1757,14 +1761,16 @@ def test( # CHECK: %[[RHS_0:.+]] = vector.load %[[RHS]][%[[X_SLICE_0]]] : memref<256xf16, strided<[1], offset: ?>>, vector<1xf16> # CHECK: %[[RHS_1:.+]] = vector.load %[[RHS]][%[[X_SLICE_1]]] : memref<256xf16, strided<[1], offset: ?>>, vector<1xf16> - # 1st Broadcast-ADD RHS + # 1st Broadcast RHS # CHECK: %[[EXTRACT_0:.+]] = vector.extract %[[RHS_0]][0] : f16 from vector<1xf16> # CHECK: %[[BCAST_RHS_0:.+]] = vector.splat %[[EXTRACT_0]] : vector<2xf16> - # CHECK: arith.addf %[[LHS_0]], %[[BCAST_RHS_0]] : vector<2xf16> - # 2nd Broadcast-ADD RHS + # 2nd Broadcast RHS # CHECK: %[[EXTRACT_1:.+]] = vector.extract %[[RHS_1]][0] : f16 from vector<1xf16> # CHECK: %[[BCAST_RHS_1:.+]] = vector.splat %[[EXTRACT_1]] : vector<2xf16> + + # Broadcast-ADD RHS + # CHECK: arith.addf %[[LHS_0]], %[[BCAST_RHS_0]] : vector<2xf16> # CHECK: arith.addf %[[LHS_1]], %[[BCAST_RHS_1]] : vector<2xf16> diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py index 446544e7..5635c8a9 100644 --- a/lit_tests/kernel/wave/expansion.py +++ b/lit_tests/kernel/wave/expansion.py @@ -266,13 +266,13 @@ def test_gemm(): # CHECK-NEXT: get_result(value=reduction, res_idx=1) # CHECK-NEXT: get_result(value=reduction, res_idx=0) # CHECK-NEXT: write(register_=getresult_0_0_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)} + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)} # CHECK-NEXT: write(register_=getresult_1_1_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16} + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16} # CHECK-NEXT: write(register_=getresult_1_0_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)} + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)} # CHECK-NEXT: write(register_=getresult_0_1_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16} + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16} # CHECK-NEXT: output # Reduction subgraph: @@ -452,13 +452,13 @@ def test_batched_gemm(): # CHECK-NEXT: get_result(value=reduction, res_idx=1) # CHECK-NEXT: get_result(value=reduction, res_idx=0) # CHECK-NEXT: write(register_=getresult_0_0_0 - # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)} + # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)} # CHECK-NEXT: write(register_=getresult_1_1_0 - # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16} + # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16} # CHECK-NEXT: write(register_=getresult_1_0_0 - # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)} + # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)} # CHECK-NEXT: write(register_=getresult_0_1_0 - # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16} + # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16} # CHECK-NEXT: output # Reduction subgraph: @@ -594,7 +594,7 @@ def test_gemm_reduction_expansion_only(): # CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0] # CHECK-NEXT: get_result(value=reduction, res_idx=0) # CHECK-NEXT: write(register_=getresult_0_0_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) # CHECK-NEXT: output(return_vals=(None,)) # Reduction subgraph: diff --git a/lit_tests/kernel/wave/minimize_global_loads.py b/lit_tests/kernel/wave/minimize_global_loads.py index 21d15ce8..7371a7bb 100644 --- a/lit_tests/kernel/wave/minimize_global_loads.py +++ b/lit_tests/kernel/wave/minimize_global_loads.py @@ -152,13 +152,13 @@ def test_gemm(): # CHECK-NEXT: get_result(value=reduction, res_idx=1) # CHECK-NEXT: get_result(value=reduction, res_idx=0) # CHECK-NEXT: write(register_=getresult_0_0_0, memory=c - # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: write(register_=getresult_1_1_0, memory=c - # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: write(register_=getresult_1_0_0, memory=c - # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: write(register_=getresult_0_1_0, memory=c - # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # Reduction subgraph: # CHECK: %acc_0_0_0