From 3bbe6f8a11e247a4a22bde43a3976a370659d298 Mon Sep 17 00:00:00 2001 From: Harsh Menon Date: Wed, 30 Oct 2024 11:12:09 -0700 Subject: [PATCH] Add support for scheduling attention This PR adds support for scheduling in attention operators. --- iree/turbine/kernel/lang/global_symbols.py | 4 + .../kernel/wave/scheduling/graph_utils.py | 83 +++++++++++- .../wave/scheduling/loop_reconstruction.py | 68 ++++++++-- .../wave/scheduling/modulo_scheduling.py | 5 +- .../kernel/wave/scheduling/resources.py | 50 +++++-- lit_tests/kernel/wave/attention.py | 123 +++++++++++++++++- lit_tests/kernel/wave/codegen.py | 4 + lit_tests/kernel/wave/scheduling.py | 4 + tests/kernel/wave/scheduling_test.py | 7 +- tests/kernel/wave/wave_attention_test.py | 6 +- tests/kernel/wave/wave_gemm_test.py | 12 ++ 11 files changed, 330 insertions(+), 36 deletions(-) diff --git a/iree/turbine/kernel/lang/global_symbols.py b/iree/turbine/kernel/lang/global_symbols.py index e1966138..5fa542be 100644 --- a/iree/turbine/kernel/lang/global_symbols.py +++ b/iree/turbine/kernel/lang/global_symbols.py @@ -27,6 +27,10 @@ READ_GLOBAL_DELAY = index_symbol("$READ_GLOBAL_DELAY") WRITE_GLOBAL_DELAY = index_symbol("$WRITE_GLOBAL_DELAY") MMA_DELAY = index_symbol("$MMA_DELAY") +VALU_DELAY = index_symbol("$VALU_DELAY") +SHUFFLE_DELAY = index_symbol("$SHUFFLE_DELAY") SHARED_MEMORY_UNITS = index_symbol("$SHARED_MEMORY_UNITS") GLOBAL_MEMORY_UNITS = index_symbol("$GLOBAL_MEMORY_UNITS") MMA_UNITS = index_symbol("$MMA_UNITS") +VALU_UNITS = index_symbol("$VALU_UNITS") +SHUFFLE_UNITS = index_symbol("$SHUFFLE_UNITS") diff --git a/iree/turbine/kernel/wave/scheduling/graph_utils.py b/iree/turbine/kernel/wave/scheduling/graph_utils.py index 34a68356..4430c145 100644 --- a/iree/turbine/kernel/wave/scheduling/graph_utils.py +++ b/iree/turbine/kernel/wave/scheduling/graph_utils.py @@ -12,6 +12,8 @@ from dataclasses import dataclass import sympy import math +from functools import partial +import multiprocessing as mp T = index_symbol("$INITIATION_INTERVAL") @@ -157,7 +159,30 @@ def find_cycles_in_scc(scc: dict[fx.Node, list[fx.Node]]) -> list[list[fx.Node]] return circuits -def all_pairs_longest_paths( +def all_pairs_longest_paths_helper( + graph: fx.Graph, u: fx.Node, dist: dict[tuple[fx.Node, fx.Node], IndexExpr], i: int +): + print(f"Got u = {u}") + v = list(graph.nodes)[i] + for w in graph.nodes: + dist[(v, w)] = sympy.Max(dist[(v, w)], dist[(v, u)] + dist[(u, w)]) + return v, dist + + +def test(i: int): + print(f"got {i}") + + +def all_pairs_longest_path_parallel(N: int, D: np.array, k: int, i: int): + """ + This function is called once for a different value of i. + """ + for j in range(N): + D[i, j] = np.maximum(D[i, j], D[i, k] + D[k, j]) + return i, D[i] + + +def all_pairs_longest_paths_symbolic( graph: fx.Graph, edges: list[Edge], ) -> dict[tuple[fx.Node, fx.Node], IndexExpr]: @@ -181,6 +206,51 @@ def all_pairs_longest_paths( return D +def all_pairs_longest_paths( + graph: fx.Graph, + edges: list[Edge], + T: int, +) -> dict[tuple[fx.Node, fx.Node], IndexExpr]: + """ + For each node in the graph, compute the longest path to all other nodes. + Uses the Floyd-Warshall algorithm and assumes that the cycles don't + have positive weights. This function computes the distances in parallel + by parallelizing across the start nodes. 
+ """ + N = len(graph.nodes) + D = np.zeros((N, N), dtype=np.float32) + negative_inf = -np.inf + for i in range(N): + for j in range(N): + D[i, j] = negative_inf + + all_nodes = list(graph.nodes) + for edge in edges: + i = all_nodes.index(edge._from) + j = all_nodes.index(edge._to) + D[i, j] = edge.weight.delay - edge.weight.iteration_difference * T + + # Parallel implementation + pool = mp.get_context("fork").Pool(processes=mp.cpu_count()) + for k in range(N): + func = partial(all_pairs_longest_path_parallel, N, D, k) + results = pool.map(func, range(N)) + for result in results: + D[result[0]] = result[1] + pool.close() + pool.join() + + # Convert from index to node based representation. + G: dict[tuple[fx.Node, fx.Node], int] = {} + for i, from_node in enumerate(graph.nodes): + for j, to_node in enumerate(graph.nodes): + if np.isinf(D[i, j]) or i == j: + continue + G[(from_node, to_node)] = int(D[i, j]) + + return G + + def evaluate_all_pairs_longest_paths( D: dict[tuple[fx.Node, fx.Node], IndexExpr], initiation_interval: int ) -> dict[tuple[fx.Node, fx.Node], int]: @@ -190,7 +260,8 @@ def evaluate_all_pairs_longest_paths( """ D_static = dict(D) for key in D_static: - D_static[key] = D_static[key].subs(T, initiation_interval) + if isinstance(D_static[key], sympy.Expr): + D_static[key] = D_static[key].subs(T, initiation_interval) # Remove the negative infinity values and edges to self. for k in list(D_static.keys()): if math.isinf(D_static[k]) or k[0] == k[1]: @@ -244,8 +315,14 @@ def get_scheduling_weight(node: fx.Node) -> EdgeWeight: weight = EdgeWeight(0, delay_table[Operation.MMA]) case IterArg(): weight = EdgeWeight(1, 0) - case CastOp(): + case CastOp() | Extract() | Permute() | Broadcast(): weight = EdgeWeight(0, delay_table[Operation.NOOP]) + case UnaryPyOp(): + weight = EdgeWeight(0, delay_table[Operation.VALU]) + case BinaryPyOp(): + weight = EdgeWeight(0, delay_table[Operation.VALU]) + case ShuffleOp(): + weight = EdgeWeight(0, delay_table[Operation.SHUFFLE]) case _: raise ValueError(f"Unsupported node type: {custom_node}") weight.delay = subs_idxc(weight.delay) diff --git a/iree/turbine/kernel/wave/scheduling/loop_reconstruction.py b/iree/turbine/kernel/wave/scheduling/loop_reconstruction.py index 8955fdea..579dae2c 100644 --- a/iree/turbine/kernel/wave/scheduling/loop_reconstruction.py +++ b/iree/turbine/kernel/wave/scheduling/loop_reconstruction.py @@ -9,6 +9,8 @@ GetResult, get_custom, SchedulingGroupBarrier, + MMA, + NewRegister, ) from .modulo_scheduling import ModuloScheduler from ..utils import ( @@ -99,10 +101,11 @@ def add_nodes_by_schedule( # Set the index for the new node by substituting the induction variable # for the current iteration. new_node.index = node.index - for dim in new_node.index: - new_node.index[dim] = new_node.index[dim].subs( - {induction_variable: current_induction_variables[iteration]} - ) + if new_node.index: + for dim in new_node.index: + new_node.index[dim] = new_node.index[dim].subs( + {induction_variable: current_induction_variables[iteration]} + ) # Add scheduling parameters for debugging. new_node.scheduling_parameters = node.scheduling_parameters # Update the rotating registers and argument context for the current node (if applicable). 
@@ -154,6 +157,21 @@ def push_placeholders( arg_context.map_arg_all(node, root_node) +def add_missing_registers(graph: fx.Graph): + """ """ + for node in graph.nodes: + custom = get_custom(node) + if isinstance(custom, MMA): + acc = get_custom(custom.acc) + if acc.graph != custom.graph: + with custom.graph.inserting_before(node): + register = NewRegister( + acc.shape, acc.dtype, acc.value + ).add_to_graph(custom.graph) + register.index = acc.index + custom.update_arg("acc", register) + + def construct_prologue( reduction_subgraph: fx.Graph, reduction: Reduction, @@ -213,6 +231,12 @@ def construct_prologue( new_init_args.append(mapped_init_arg) reduction.init_args = new_init_args + # Add missing registers. Since registers are not present + # in the scheduling code, we could end up with a situation where + # we move mma ops outside the reduction that do not have a corresponding + # register. We remedy this in the function below. + add_missing_registers(reduction.graph) + def flatten_dict_values( rotating_registers: dict[fx.Node, list[fx.Node]] @@ -386,6 +410,12 @@ def construct_kernel( "kernel.png", ) + # Add missing registers. Since registers are not present + # in the scheduling code, we could end up with a situation where + # we move mma ops outside the reduction that do not have a corresponding + # register. We remedy this in the function below. + add_missing_registers(pipelined_reduction_graph) + return pipelined_reduction, pipelined_reduction_graph @@ -422,16 +452,30 @@ def construct_epilogue( scheduler.num_stages, ) - existing_get_results: list[GetResult] = sorted( - [x for x in pipelined_reduction.users if isinstance(x, GetResult)], - key=lambda x: x.res_idx, - ) - existing_users = {x: x.users for x in existing_get_results} + existing_get_results: list[GetResult] = [ + x for x in pipelined_reduction.users if isinstance(x, GetResult) + ] + existing_indices = [x.res_idx for x in existing_get_results] # Map the results from the kernel to the init args (for stages). - for iter_arg, get_result in zip( - reduction.iter_args(reduction_subgraph), existing_get_results - ): + # The number of iter args may not be the same as the number of get results + # and so we have to add additional get results for the missing iter args. + iter_args = reduction.iter_args(reduction_subgraph) + for i in range(len(iter_args)): + if i in existing_indices: + continue + with pipelined_reduction.graph.inserting_before( + existing_get_results[0].fx_node.next + ): + result = GetResult(pipelined_reduction.fx_node, i).add_to_graph( + pipelined_reduction.graph + ) + existing_get_results.append(get_custom(result)) + + existing_get_results = sorted(existing_get_results, key=lambda x: x.res_idx) + existing_users = {x: x.users for x in existing_get_results} + + for iter_arg, get_result in zip(iter_args, existing_get_results): arg_context.map_arg_all(iter_arg, get_result.fx_node) with pipelined_reduction.graph.inserting_before( diff --git a/iree/turbine/kernel/wave/scheduling/modulo_scheduling.py b/iree/turbine/kernel/wave/scheduling/modulo_scheduling.py index 82940113..00c6cd78 100644 --- a/iree/turbine/kernel/wave/scheduling/modulo_scheduling.py +++ b/iree/turbine/kernel/wave/scheduling/modulo_scheduling.py @@ -106,9 +106,6 @@ def schedule_graph(self) -> tuple[dict[fx.Node, int], bool]: # Initialize initiation interval. T0 = int(max(self.compute_resource_ii(), self.compute_recurrence_ii(sccs))) - # Compute symbolic all pairs longest path. 
- e_star_symbolic = all_pairs_longest_paths(self.graph, self.edges) - # Generate the schedule. # TODO: Come up with a better heuristic on an upper bound for the initiation interval. T_max_range = 3 * T0 @@ -116,7 +113,7 @@ def schedule_graph(self) -> tuple[dict[fx.Node, int], bool]: for T in range(T0, T0 + T_max_range): logger.debug(f"Trying initiation interval: {T}.") self.RT = np.zeros((T, len(self.resources))) - self.e_star = evaluate_all_pairs_longest_paths(e_star_symbolic, T) + self.e_star = all_pairs_longest_paths(self.graph, self.edges, T) logger.debug(f"All Pairs Longest Paths: {self.e_star}.") self.schedule: dict[fx.Node, int] = {} for _, scc in topological_sort(sccs).items(): diff --git a/iree/turbine/kernel/wave/scheduling/resources.py b/iree/turbine/kernel/wave/scheduling/resources.py index 15dd76d3..cc19a7fa 100644 --- a/iree/turbine/kernel/wave/scheduling/resources.py +++ b/iree/turbine/kernel/wave/scheduling/resources.py @@ -15,6 +15,12 @@ get_custom, CustomOp, CastOp, + UnaryPyOp, + BinaryPyOp, + ShuffleOp, + Permute, + Extract, + Broadcast, ) import torch.fx as fx from enum import Enum @@ -23,7 +29,13 @@ # This table contains the number of functional units available for each operation. def get_available_resources() -> list[int]: - resources = [GLOBAL_MEMORY_UNITS, SHARED_MEMORY_UNITS, MMA_UNITS] + resources = [ + GLOBAL_MEMORY_UNITS, + SHARED_MEMORY_UNITS, + MMA_UNITS, + VALU_UNITS, + SHUFFLE_UNITS, + ] return np.array([int(subs_idxc(x)) for x in resources]) @@ -37,8 +49,11 @@ class Operation(Enum): VALU = "valu" SALU = "salu" NOOP = "noop" + SHUFFLE = "shuffle" +SCHEDULING_NOOPS = (IterArg, Permute, Extract, Broadcast, CastOp) + # This table contains the cycles required to execute each operation. delay_table = { Operation.READ_SHARED: READ_SHARED_DELAY, @@ -47,17 +62,21 @@ class Operation(Enum): Operation.WRITE_GLOBAL: WRITE_GLOBAL_DELAY, Operation.MMA: MMA_DELAY, Operation.NOOP: 0, + Operation.VALU: VALU_DELAY, + Operation.SHUFFLE: SHUFFLE_DELAY, } # This table contains the resource usage for each operation. # Operations can use more than one resource for more than one cycle. 
resource_reservation_table = { - Operation.READ_SHARED: np.array([[0, 1, 0]]), - Operation.WRITE_SHARED: np.array([[0, 1, 0]]), - Operation.READ_GLOBAL: np.array([[1, 0, 0]]), - Operation.WRITE_GLOBAL: np.array([[1, 0, 0]]), - Operation.MMA: np.array([[0, 0, 1]]), - Operation.NOOP: np.array([[0, 0, 0]]), + Operation.READ_SHARED: np.array([[0, 1, 0, 0, 0]]), + Operation.WRITE_SHARED: np.array([[0, 1, 0, 0, 0]]), + Operation.READ_GLOBAL: np.array([[1, 0, 0, 0, 0]]), + Operation.WRITE_GLOBAL: np.array([[1, 0, 0, 0, 0]]), + Operation.MMA: np.array([[0, 0, 1, 0, 0]]), + Operation.NOOP: np.array([[0, 0, 0, 0, 0]]), + Operation.VALU: np.array([[0, 0, 0, 1, 0]]), + Operation.SHUFFLE: np.array([[0, 0, 0, 0, 1]]), } @@ -76,10 +95,12 @@ def get_custom_operation_type(custom: CustomOp) -> Operation: ) elif isinstance(custom, MMA): return Operation.MMA - elif isinstance(custom, IterArg): - return Operation.NOOP - elif isinstance(custom, Output): + elif isinstance(custom, SCHEDULING_NOOPS + (Output,)): return Operation.NOOP + elif isinstance(custom, (UnaryPyOp, BinaryPyOp)): + return Operation.VALU + elif isinstance(custom, ShuffleOp): + return Operation.SHUFFLE else: return None @@ -106,8 +127,13 @@ def annotate_resource_usage( ) elif isinstance(custom, MMA): custom.rrt = resource_reservation_table[Operation.MMA] - elif isinstance(custom, (IterArg, CastOp)): - iter_args.append(node) + elif isinstance(custom, ShuffleOp): + custom.rrt = resource_reservation_table[Operation.SHUFFLE] + elif isinstance(custom, (UnaryPyOp, BinaryPyOp)): + custom.rrt = resource_reservation_table[Operation.VALU] + elif isinstance(custom, SCHEDULING_NOOPS): + if isinstance(custom, IterArg): + iter_args.append(node) custom.rrt = resource_reservation_table[Operation.NOOP] elif isinstance(custom, Output): output = node diff --git a/lit_tests/kernel/wave/attention.py b/lit_tests/kernel/wave/attention.py index a862499c..2b0a4f52 100644 --- a/lit_tests/kernel/wave/attention.py +++ b/lit_tests/kernel/wave/attention.py @@ -31,6 +31,127 @@ STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD +@run_test +def test_attention_pipelined(): + shape = (8, 128, 128, 64, 256) + # Expose user-constraints + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] + constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + mfma_variant = tkw.MMAType.F32_16x16x16_F16 + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(2, 2, 1), + mma_type=mfma_variant, + vector_shapes={B: 0, M: 16, N: 16}, + ) + ] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + k = tkw.IndexMapping.iterator(2) + mapping = tkw.IndexMapping( + num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} + ) + + @tkw.wave(constraints) + def base_attention_pipelined( + q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], + k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], + v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[B, N, M, tkl.f32](0.0) + init_sum = tkl.Register[B, M, tkl.f32](0.0) + init_max = tkl.Register[B, M, tkl.f32](-1e6) + + # This microkernel encodes the fact that if the reduction + # dimension were tiled, then we would need to materialize a loop. 
+ @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) + def repeat( + partial_max: tkl.Register[B, M, tkl.f32], + partial_sum: tkl.Register[B, M, tkl.f32], + acc: tkl.Register[B, N, M, tkl.f32], + ) -> ( + tkl.Register[B, M, tkl.f32], + tkl.Register[B, M, tkl.f32], + tkl.Register[B, N, M, tkl.f32], + ): + imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0) + q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) + k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) + inner_acc = tkw.mma(k_reg, q_reg, imm_reg) + x_j = tkw.permute(inner_acc, target_shape=[B, M, K2]) + m_j = tkw.max(x_j, partial_max, dim=K2) + e_delta_max = tkw.exp2(partial_max - m_j) + e_delta = tkw.exp2(x_j - m_j) + e_init = partial_sum * e_delta_max + d_j = tkw.sum(e_delta, e_init, dim=K2) + imm_f16 = tkw.cast(e_delta, tkl.f16) + v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) + new_acc = acc * e_delta_max + acc = tkw.mma(v_reg, imm_f16, new_acc) + return m_j, d_j, acc + + # repeat represents the results of the loop + res_max, res_sum, res_mm = repeat + res = res_mm / res_sum + tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), + STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), + BLOCK_B: 1, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K2: 32, + B: shape[0], + M: shape[1], + N: shape[2], + K1: shape[3], + K2: shape[4], + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + VALU_DELAY: 1, + SHUFFLE_DELAY: 1, + SHARED_MEMORY_UNITS: 12, + GLOBAL_MEMORY_UNITS: 4, + MMA_UNITS: 4, + VALU_UNITS: 8, + SHUFFLE_UNITS: 8, + } + + with tk.gen.TestLaunchContext( + hyperparams, + canonicalize=True, + run=False, + run_bench=False, + schedule=True, + use_scheduling_barriers=False, + ): + torch.manual_seed(0) + q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16) + k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16) + v = torch.randn(shape[0], shape[4], shape[2], dtype=torch.float16) + output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) + print(base_attention_pipelined(q, k, v, output).module_op) + + # CHECK: func.func @base_attention_pipelined + # CHECK: {{.*}} = scf.for + # CHECK-COUNT-16: {{.*}} = amdgpu.mfma + # CHECK-COUNT-8: {{.*}} = gpu.shuffle xor {{.*}} + # CHECK-COUNT-8: {{.*}} = amdgpu.mfma + + @run_test def test_attention_32x32x8(): shape = (8, 128, 128, 64, 256) @@ -205,9 +326,7 @@ def repeat( ): imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0) q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) - # b_reg: tkw.Register[B, N, K, tkl.f16] k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) - # acc: tkw.Register[B, N, M, tkl.f32] inner_acc = tkw.mma(k_reg, q_reg, imm_reg) x_j = tkw.permute(inner_acc, target_shape=[B, M, K2]) m_j = tkw.max(x_j, partial_max, dim=K2) diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index fc5e482c..cb0d5080 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -1140,9 +1140,13 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: READ_GLOBAL_DELAY: 2, WRITE_GLOBAL_DELAY: 2, MMA_DELAY: 1, + VALU_DELAY: 1, + SHUFFLE_DELAY: 1, SHARED_MEMORY_UNITS: 4, GLOBAL_MEMORY_UNITS: 4, MMA_UNITS: 4, + VALU_UNITS: 8, + SHUFFLE_UNITS: 8, }, canonicalize=True, schedule=True, diff --git 
a/lit_tests/kernel/wave/scheduling.py b/lit_tests/kernel/wave/scheduling.py index afa6065b..87936809 100644 --- a/lit_tests/kernel/wave/scheduling.py +++ b/lit_tests/kernel/wave/scheduling.py @@ -88,6 +88,10 @@ def test_gemm_pipelined(): SHARED_MEMORY_UNITS: 2, GLOBAL_MEMORY_UNITS: 2, MMA_UNITS: 2, + VALU_DELAY: 1, + VALU_UNITS: 2, + SHUFFLE_DELAY: 1, + SHUFFLE_UNITS: 2, } ): trace: CapturedTrace = gemm_pipelined() diff --git a/tests/kernel/wave/scheduling_test.py b/tests/kernel/wave/scheduling_test.py index d8728e3d..269d7378 100644 --- a/tests/kernel/wave/scheduling_test.py +++ b/tests/kernel/wave/scheduling_test.py @@ -178,9 +178,8 @@ def testGraphUtils(self): def testAPLP(self): graph, weighted_edges, nodes = self.create_weighted_graph() - D = all_pairs_longest_paths(graph, weighted_edges) T = 4 - D3 = evaluate_all_pairs_longest_paths(D, T) + D3 = all_pairs_longest_paths(graph, weighted_edges, T) assert D3[(nodes["a"], nodes["b"])] == 2 assert D3[(nodes["a"], nodes["c"])] == 3 assert D3[(nodes["a"], nodes["d"])] == 4 @@ -273,6 +272,10 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: SHARED_MEMORY_UNITS: 2, GLOBAL_MEMORY_UNITS: 2, MMA_UNITS: 2, + VALU_DELAY: 1, + VALU_UNITS: 2, + SHUFFLE_DELAY: 1, + SHUFFLE_UNITS: 2, } with tk.gen.TestLaunchContext(hyperparams, canonicalize=True, schedule=True): trace: CapturedTrace = gemm() diff --git a/tests/kernel/wave/wave_attention_test.py b/tests/kernel/wave/wave_attention_test.py index 792d9cff..f62bd103 100644 --- a/tests/kernel/wave/wave_attention_test.py +++ b/tests/kernel/wave/wave_attention_test.py @@ -349,7 +349,7 @@ def repeat( @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_attention")) -@pytest.mark.parametrize("enable_scheduling", [False]) +@pytest.mark.parametrize("enable_scheduling", [False, True]) @pytest.mark.parametrize( "mfma_variant", [ @@ -474,9 +474,13 @@ def repeat( READ_GLOBAL_DELAY: 2, WRITE_GLOBAL_DELAY: 2, MMA_DELAY: 1, + VALU_DELAY: 1, + SHUFFLE_DELAY: 1, SHARED_MEMORY_UNITS: 4, GLOBAL_MEMORY_UNITS: 4, MMA_UNITS: 4, + VALU_UNITS: 8, + SHUFFLE_UNITS: 8, } config = {"backend": "rocm", "device": "hip", "target": "gfx942"} if run_bench: diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py index 7c512b24..7bbe0629 100644 --- a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -142,9 +142,13 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: READ_GLOBAL_DELAY: 2, WRITE_GLOBAL_DELAY: 2, MMA_DELAY: 1, + VALU_DELAY: 1, + SHUFFLE_DELAY: 1, SHARED_MEMORY_UNITS: 4, GLOBAL_MEMORY_UNITS: 4, MMA_UNITS: 4, + VALU_UNITS: 8, + SHUFFLE_UNITS: 8, } config = {"backend": "rocm", "device": "hip", "target": "gfx942"} if run_bench: @@ -261,9 +265,13 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: READ_GLOBAL_DELAY: 2, WRITE_GLOBAL_DELAY: 2, MMA_DELAY: 1, + VALU_DELAY: 1, + SHUFFLE_DELAY: 1, SHARED_MEMORY_UNITS: 4, GLOBAL_MEMORY_UNITS: 4, MMA_UNITS: 4, + VALU_UNITS: 8, + SHUFFLE_UNITS: 8, } config = {"backend": "rocm", "device": "hip", "target": "gfx942"} if run_bench: @@ -376,9 +384,13 @@ def repeat( READ_GLOBAL_DELAY: 2, WRITE_GLOBAL_DELAY: 2, MMA_DELAY: 1, + VALU_DELAY: 1, + SHUFFLE_DELAY: 1, SHARED_MEMORY_UNITS: 4, GLOBAL_MEMORY_UNITS: 4, MMA_UNITS: 4, + VALU_UNITS: 8, + SHUFFLE_UNITS: 8, } config = {"backend": "rocm", "device": "hip", "target": "gfx942"} if run_bench:
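Note on the all-pairs longest-paths change in graph_utils.py: the symbolic sympy formulation is replaced by a numeric Floyd-Warshall over a dense matrix seeded with `delay - iteration_difference * T`, with the per-`k` row updates fanned out across a multiprocessing pool. The sketch below is a minimal, self-contained restatement of that max-plus recurrence on a toy graph (node indices and edge weights here are made up for illustration); it is not the scheduler's fx.Graph-based API.

```python
import numpy as np

def all_pairs_longest_paths_dense(num_nodes, weighted_edges, T):
    # Seed with -inf (no path), then each edge's delay - iteration_difference * T.
    D = np.full((num_nodes, num_nodes), -np.inf, dtype=np.float64)
    for i, j, delay, iter_diff in weighted_edges:
        D[i, j] = delay - iter_diff * T
    for k in range(num_nodes):
        # Max-plus Floyd-Warshall step. The patch distributes this update across
        # a process pool, one start node (row) per worker, for each fixed k.
        D = np.maximum(D, D[:, k:k + 1] + D[k:k + 1, :])
    return D

# Toy graph: a -> b -> c plus a loop-carried back edge c -> a (iteration_difference = 1).
edges = [(0, 1, 2, 0), (1, 2, 1, 0), (2, 0, 1, 1)]
print(all_pairs_longest_paths_dense(3, edges, T=4))
```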
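The resource tables in resources.py grow from three columns (global memory, shared memory, MMA) to five (adding VALU and shuffle units) so the elementwise and shuffle ops introduced by attention can be modulo-scheduled. The snippet below is a hedged illustration of how such a reservation-table row is checked against the available units at a cycle modulo T; `try_place` is a hypothetical helper for exposition, not the scheduler's actual interface, though the column order and unit counts mirror the patch.

```python
import numpy as np

AVAILABLE = np.array([4, 12, 4, 8, 8])  # GLOBAL, SHARED, MMA, VALU, SHUFFLE units

RRT = {
    "read_shared": np.array([[0, 1, 0, 0, 0]]),
    "mma":         np.array([[0, 0, 1, 0, 0]]),
    "valu":        np.array([[0, 0, 0, 1, 0]]),
    "shuffle":     np.array([[0, 0, 0, 0, 1]]),
}

def try_place(RT, op_rrt, cycle, T):
    """Reserve the op's resources at `cycle` mod T if capacity allows."""
    rows = [(cycle + r) % T for r in range(op_rrt.shape[0])]
    if np.all(RT[rows] + op_rrt <= AVAILABLE):
        RT[rows] += op_rrt
        return True
    return False

T = 4
RT = np.zeros((T, len(AVAILABLE)))               # one row per cycle mod T
print(try_place(RT, RRT["mma"], cycle=6, T=T))       # True: MMA unit free at cycle 2
print(try_place(RT, RRT["shuffle"], cycle=2, T=T))   # True: shuffle uses a separate unit
```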
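For reference, the `repeat` body in the pipelined attention kernel carries the standard running-max / running-sum (flash-attention style) recurrence through the K2 loop: `m_j` tracks the row maximum, `d_j` the rescaled softmax denominator, and `acc` the rescaled partial matmul. A plain NumPy sketch of that update rule follows; it uses `exp` instead of `exp2` and ignores the kernel's tiling, wave constraints, and mixed f16/f32 precision, so treat it as an illustration of the math only.

```python
import numpy as np

def online_softmax_attention(q, k, v, block=32):
    m = np.full(q.shape[0], -1e6)            # running max per query row (partial_max)
    d = np.zeros(q.shape[0])                 # running softmax denominator (partial_sum)
    acc = np.zeros((q.shape[0], v.shape[1])) # running output accumulator
    for s in range(0, k.shape[0], block):
        x = q @ k[s:s + block].T             # scores for this K2 tile
        m_new = np.maximum(m, x.max(axis=1))
        scale = np.exp(m - m_new)            # e_delta_max in the kernel
        p = np.exp(x - m_new[:, None])       # e_delta in the kernel
        d = d * scale + p.sum(axis=1)
        acc = acc * scale[:, None] + p @ v[s:s + block]
        m = m_new
    return acc / d[:, None]

q = np.random.randn(16, 8)
k = np.random.randn(64, 8)
v = np.random.randn(64, 8)
ref = np.exp(q @ k.T - (q @ k.T).max(1, keepdims=True))
ref = ref / ref.sum(1, keepdims=True) @ v
assert np.allclose(online_softmax_attention(q, k, v), ref)
```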