From 383994111d88cbd9b3272ed3aa42af39e5965a55 Mon Sep 17 00:00:00 2001
From: Harsh Menon
Date: Mon, 4 Nov 2024 14:00:22 -0800
Subject: [PATCH] Add type inference

This PR adds a type inference pass to wave. Previously, types were
inferred on demand by looking up the types of neighboring nodes, which
resulted in inefficient and repeated type inference. Instead, we now
introduce a pass that infers the types of all operators in the graph
and stores the inferred type on each node.

Signed-off-by: Harsh Menon
---
 iree/turbine/kernel/ops/wave_ops.py           |  94 +++++----
 iree/turbine/kernel/wave/expansion.py         |   2 +-
 iree/turbine/kernel/wave/type_inference.py    |  21 ++
 iree/turbine/kernel/wave/wave.py              |   4 +
 lit_tests/kernel/wave/barriers.py             |   3 +
 lit_tests/kernel/wave/expansion.py            |  10 +
 .../kernel/wave/index_sequence_analysis.py    |   2 +
 .../kernel/wave/minimize_global_loads.py      |   2 +
 lit_tests/kernel/wave/promotion.py            |   6 +-
 lit_tests/kernel/wave/scheduling.py           |   2 +
 tests/kernel/wave/scheduling_test.py          |   2 +
 tests/kernel/wave/type_inference_test.py      | 199 ++++++++++++++++++
 tests/kernel/wave/visualization_test.py       |   2 +
 13 files changed, 306 insertions(+), 43 deletions(-)
 create mode 100644 iree/turbine/kernel/wave/type_inference.py
 create mode 100644 tests/kernel/wave/type_inference_test.py

diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py
index 89b96ccf..6e38ea66 100644
--- a/iree/turbine/kernel/ops/wave_ops.py
+++ b/iree/turbine/kernel/ops/wave_ops.py
@@ -338,7 +338,7 @@ def custom_string(self, value_map: dict[str, str]) -> str:
         vars_str = ", ".join(vars_list)
         return f"{self.tkw_op_name}({vars_str})"

-    def add_to_graph(self, region_graph: RegionGraph) -> fx.Node:
+    def add_to_graph(self, region_graph: RegionGraph, type: Any = None) -> fx.Node:
         arg_list = tuple([value for _, value in vars(self).items()])
         self.graph = region_graph
         self.fx_node = region_graph.create_node(
@@ -350,6 +350,10 @@ def add_to_graph(self, region_graph: RegionGraph) -> fx.Node:
         self.fx_node.tkw_op = self.__class__
         self.fx_node.tkw_op_name = self.tkw_op_name
         self.fx_node.index = None
+        if type is None:
+            get_custom(self.fx_node).infer_type()
+        else:
+            self.fx_node.type = type
         return self.fx_node

     def _add_proxy_to_graph(self, region_graph: RegionGraph):
@@ -556,6 +560,23 @@ def vector_shapes(self) -> dict[IndexSymbol, int]:
     def vector_shapes(self, value: dict[IndexSymbol, int]):
         self.fx_node.vector_shapes = value

+    @property
+    def type(self) -> Any:
+        if hasattr(self.fx_node, "type"):
+            return self.fx_node.type
+        return None
+
+    @type.setter
+    def type(self, value: Any):
+        self.fx_node.type = value
+
+    def infer_type(self):
+        """
+        Infer the type of this operator using the types
+        of its arguments.
+        """
+        pass
+
     def align_index(self, constraints: list["Constraint"]) -> None:
         """
         Align index to WG/Tile sizes.
@@ -602,13 +623,13 @@ def indexing_dims(self) -> list[IndexSymbol]:
     def py_operator(self) -> str:
         return self.tkw_op_name

-    @property
-    def type(self) -> Memory:
+    def infer_type(self):
         lhs_type = get_custom(self.lhs).type
         rhs_type = get_custom(self.rhs).type
         has_same_type = has_same_custom_type(lhs_type, rhs_type)
         if has_same_type:
-            return lhs_type
+            self.type = lhs_type
+            return
         lhs_dim_set = set(lhs_type.symbolic_shape)
         rhs_dim_set = set(rhs_type.symbolic_shape)
         if lhs_dim_set.isdisjoint(rhs_dim_set):
             raise ValueError(
                 "BinaryPyOp requires lhs and rhs shape to be at least broadcastable."
             )
         broadcasted_type = lhs_type if lhs_dim_set > rhs_dim_set else rhs_type
-        return broadcasted_type
+        self.type = broadcasted_type


 @define_interface_op("exp2")
@@ -637,10 +658,9 @@ def indexing_dims(self) -> list[IndexSymbol]:
     def py_operator(self) -> str:
         return self.tkw_op_name

-    @property
-    def type(self) -> Memory:
+    def infer_type(self):
         src_type = get_custom(self.arg).type
-        return src_type
+        self.type = src_type


 @final
@@ -868,9 +888,8 @@ def rhs_type(self) -> Memory:
     def acc_type(self) -> Memory:
         return get_custom(self.acc).type

-    @property
-    def type(self) -> Memory:
-        return self.acc_type
+    def infer_type(self):
+        self.type = self.acc_type

     def operand_index(
         self, operand_map: dict[IndexSymbol, int], shape: list[IndexExpr]
@@ -925,6 +944,7 @@ def reduction_dim(self, value: IndexSymbol):
 @define_op("read")
 @dataclass
 class Read(CustomOp):
+
     memory: fx.Proxy
     elements_per_thread: Optional[Any] = None
     mapping: Optional[IndexMapping] = None
@@ -937,10 +957,9 @@ def indexing_dims(self) -> list[IndexSymbol]:
         # TODO: This could contain ints.
         return list(self.memory_type.symbolic_shape)

-    @property
-    def type(self) -> "Register":
+    def infer_type(self):
         dtype = self.memory_type.dtype
-        return Register[*self.indexing_dims, dtype]
+        self.type = Register[*self.indexing_dims, dtype]

     @property
     def memory_type(self) -> "Memory":
@@ -1052,12 +1071,11 @@ def captured_vars(self, graph: fx.Graph) -> list[fx.Node]:
                 captured_vars.append(nested_node)
         return captured_vars

-    @property
-    def type(self) -> Memory | Register | list[Memory | Register]:
+    def infer_type(self):
         res_types = [get_custom(x).type for x in self.init_args]
         if len(res_types) == 1:
             res_types = res_types[0]
-        return res_types
+        self.type = res_types

     def outputs(self, graph: fx.Graph) -> list[fx.Node]:
         for node in graph.nodes:
@@ -1110,11 +1128,12 @@ def indexing_dims(self) -> list[IndexSymbol]:
         if self.mapping is not None:
             return list(self.mapping.input_shape)
         # TODO: This could contain ints.
-        return list(self.type.symbolic_shape)
+        return list(self.memory_type.symbolic_shape)

-    @property
-    def type(self) -> "Memory":
-        return get_custom(self.memory).type
+    def infer_type(self):
+        address_space = self.memory_type.address_space
+        dtype = self.memory_type.dtype
+        self.type = Memory[*self.indexing_dims, address_space, dtype]

     @property
     def memory_type(self) -> "Memory":
@@ -1144,13 +1163,12 @@ class GetResult(CustomOp):
     value: fx.Node
     res_idx: int

-    @property
-    def type(self) -> "Memory":
+    def infer_type(self):
         src_type = get_custom(self.value).type
         if isinstance(src_type, list):
-            return src_type[self.res_idx]
+            self.type = src_type[self.res_idx]
         else:
-            return src_type
+            self.type = src_type

     @property
     def indexing_dims(self) -> list[IndexExpr]:
@@ -1200,14 +1218,14 @@ class Extract(CustomOp):
     register_: fx.Proxy
     offset: IndexExpr | int

-    @property
-    def type(self) -> "Register":
+    def infer_type(self):
         # Intuition here is we are trying to extract an element
         # from fastest dim => we reduce the fastest dim.
         src_type = get_custom(self.register_).type
         # Return itself if just 0-D/1-D symbolic.
         if len(src_type.symbolic_shape) <= 1:
-            return src_type
+            self.type = src_type
+            return
         # Typically fastest dim is the last dimension,
         # If non-unit dim exists => non-unit dim is fastest dim.
@@ -1220,7 +1238,7 @@ def type(self) -> "Register":
         dim_to_remove = dst_shape[-1] if not non_unit_dim else non_unit_dim[0]
         dst_shape.remove(dim_to_remove)
         dst_type = Register[*dst_shape, src_type.dtype]
-        return dst_type
+        self.type = dst_type


 @define_op("extract_slice")
@@ -1297,12 +1315,8 @@ def indexing_dims(self) -> list[IndexSymbol]:
         dst_indexing = [dim for dim in src_indexing if dim != self.dim]
         return dst_indexing

-    @property
-    def type(self) -> Memory:
+    def infer_type(self):
         if isinstance(self.arg, Sequence):
-            # Local import to break circular dep.
-            from ..wave.utils import all_equal
             src_types = [get_custom(arg).type for arg in self.arg]
             ref_shape = src_types[0].symbolic_shape
             ref_dtype = src_types[0].dtype
@@ -1318,7 +1332,7 @@ def type(self) -> Memory:
         src_type = get_custom(self.arg).type
         reduced_dims = [dims for dims in src_type.symbolic_shape if dims != self.dim]
         dst_type = Register[*reduced_dims, src_type.dtype]
-        return dst_type
+        self.type = dst_type

     @property
     def num_reduction_dims(self) -> int:
@@ -1376,10 +1390,9 @@ class CastOp(CustomOp, ABC):
     def indexing_dims(self) -> list[IndexSymbol]:
         return get_custom(self.arg).indexing_dims

-    @property
-    def type(self) -> Memory:
+    def infer_type(self):
         src_shape = get_custom(self.arg).type.symbolic_shape
-        return Register[*src_shape, self.dtype]
+        self.type = Register[*src_shape, self.dtype]


 @define_op("permute")
@@ -1397,13 +1410,12 @@ class Permute(CustomOp, ABC):
     def indexing_dims(self) -> list[IndexExpr]:
         return self.target_shape

-    @property
-    def type(self) -> Register:
+    def infer_type(self):
         src_type = get_custom(self.arg).type
         assert set(src_type.symbolic_shape) == set(
             self.target_shape
         ), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}"
-        return Register[*self.target_shape, src_type.dtype]
+        self.type = Register[*self.target_shape, src_type.dtype]


 def _to_sequence(input: Any | Sequence[Any]) -> Sequence[Any]:
diff --git a/iree/turbine/kernel/wave/expansion.py b/iree/turbine/kernel/wave/expansion.py
index 55c1db85..ec1c7d18 100644
--- a/iree/turbine/kernel/wave/expansion.py
+++ b/iree/turbine/kernel/wave/expansion.py
@@ -336,7 +336,7 @@ def _expand_reduction(
         # Add GetResult nodes for the corresponding dimensions
         reduction.graph.inserting_after(reduction.fx_node)
         new_node = GetResult(reduction.fx_node, len(new_output_args))
-        new_node.add_to_graph(reduction.graph)
+        new_node.add_to_graph(reduction.graph, arg.type)
         new_node.fx_node.name = get_expanded_name(new_node, dims)
         context[
             (reduction, get_indexed_dims(dims, expand_dims), arg_idx)
diff --git a/iree/turbine/kernel/wave/type_inference.py b/iree/turbine/kernel/wave/type_inference.py
new file mode 100644
index 00000000..db574cfc
--- /dev/null
+++ b/iree/turbine/kernel/wave/type_inference.py
@@ -0,0 +1,21 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ..ops.wave_ops import *
+from .._support.tracing import CapturedTrace
+import torch.fx as fx
+from ...support.logging import get_logger
+
+logger = get_logger("turbine.wave.type_inference")
+
+
+def infer_types(trace: CapturedTrace | fx.Graph):
+    # Infer and set the types for all nodes in the graph.
+    for subgraph in trace.region_graph.subgraphs.values():
+        for node in subgraph.nodes:
+            custom = get_custom(node)
+            custom.infer_type()
+            logger.debug(f"Setting type for {custom.fx_node} = {custom.type}")
diff --git a/iree/turbine/kernel/wave/wave.py b/iree/turbine/kernel/wave/wave.py
index 07ca9ab1..7574f032 100644
--- a/iree/turbine/kernel/wave/wave.py
+++ b/iree/turbine/kernel/wave/wave.py
@@ -49,6 +49,7 @@
 from .thread_shape_analysis import determine_thread_shapes
 from .scheduling.schedule import schedule_graph
 from .._support.indexing import IndexingContext, IndexExpr
+from .type_inference import infer_types
 import iree.turbine.kernel.lang as tkl
 from .._support.tracing import (
     CapturedTrace,
@@ -224,6 +225,9 @@ def _trace_and_get_kernel_signature(
         # Initialize Vector shapes
         self.hardware_constraints[0].subs_vector_shapes(idxc.subs)

+        # Do type inference.
+        infer_types(graph)
+
         # Promote the placeholders to the appropriate address space.
         promote_placeholders(graph, self.constraints)
         hoist_allocs(graph)
diff --git a/lit_tests/kernel/wave/barriers.py b/lit_tests/kernel/wave/barriers.py
index c4c02ccd..6a67cfb9 100644
--- a/lit_tests/kernel/wave/barriers.py
+++ b/lit_tests/kernel/wave/barriers.py
@@ -14,6 +14,7 @@
 from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers
 from iree.turbine.kernel.wave.hoisting import hoist_allocs
 from iree.turbine.kernel.wave.expansion import expand_graph
+from iree.turbine.kernel.wave.type_inference import infer_types
 from iree.turbine.kernel.lang.global_symbols import *
 from iree.turbine.kernel._support.tracing import CapturedTrace
 from iree.turbine.kernel._support.indexing import IndexingContext
@@ -86,6 +87,7 @@ def test_read_write_equal_sizes():
     graph: fx.Graph = trace.get_root_graph()
     read_node = get_read_nodes(graph)[0]
     IndexingContext.current().finalize()
+    infer_types(trace)
     promote_node(read_node, SHARED_ADDRESS_SPACE, constraints)
     set_node_indices(trace, constraints)
     expand_graph(trace, constraints)
@@ -171,6 +173,7 @@ def test_gemm():
     trace: CapturedTrace = gemm()
     graph: fx.Graph = trace.get_subgraph("region_0")
     IndexingContext.current().finalize()
+    infer_types(trace)
     read_nodes = get_read_nodes(graph)
     for read_node in read_nodes:
         promote_node(read_node, SHARED_ADDRESS_SPACE, constraints)
diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py
index 4f00c54e..1c86de37 100644
--- a/lit_tests/kernel/wave/expansion.py
+++ b/lit_tests/kernel/wave/expansion.py
@@ -6,6 +6,7 @@
 import iree.turbine.kernel.lang as tkl
 import iree.turbine.kernel.wave as tkw
 from iree.turbine.kernel.wave.expansion import expand_graph
+from iree.turbine.kernel.wave.type_inference import infer_types
 from iree.turbine.kernel.wave.index_sequence_analysis import (
     set_node_indices,
     set_post_expansion_indices,
@@ -69,6 +70,7 @@ def test_read_write_equal_sizes():
     ):
         graph = read_write_same_size()
         IndexingContext.current().finalize()
+        infer_types(graph)
         set_node_indices(graph, constraints)
         expand_graph(graph, constraints)
         set_post_expansion_indices(graph, constraints)
@@ -150,6 +152,7 @@ def test_read_write():
     ):
         graph = read_write_different_dims()
         IndexingContext.current().finalize()
+        infer_types(graph)
         set_node_indices(graph, constraints)
         expand_graph(graph, constraints)
         set_post_expansion_indices(graph, constraints)
@@ -227,6 +230,7 @@ def test_gemm():
     ):
         graph = gemm()
         IndexingContext.current().finalize()
+        infer_types(graph)
         set_node_indices(graph, constraints)
         expand_graph(graph, constraints)
         set_post_expansion_indices(graph, constraints)
@@ -413,6 +417,7 @@ def test_batched_gemm():
     ):
         graph = batched_gemm()
         IndexingContext.current().finalize()
+        infer_types(graph)
         set_node_indices(graph, constraints)
         expand_graph(graph, constraints)
         set_post_expansion_indices(graph, constraints)
@@ -591,6 +596,7 @@ def test_gemm_non_direct_acc():
     ):
         graph = gemm_non_direct_acc()
         IndexingContext.current().finalize()
+        infer_types(graph)
         set_node_indices(graph, constraints)
         expand_graph(graph, constraints)
         set_post_expansion_indices(graph, constraints)
@@ -657,6 +663,7 @@ def test_tiled_max():
     ):
         graph = tiled_max()
         IndexingContext.current().finalize()
+        infer_types(graph)
         set_node_indices(graph, constraints)
         expand_graph(graph, constraints)
         set_post_expansion_indices(graph, constraints)
@@ -688,6 +695,7 @@ def test_gemm_reduction_expansion_only():
     ):
         graph = gemm()
         IndexingContext.current().finalize()
+        infer_types(graph)
         set_node_indices(graph, constraints)
         expand_graph(graph, constraints)
         set_post_expansion_indices(graph, constraints)
@@ -791,6 +799,7 @@ def py_arithmetic_different_dims():
     ):
         graph = py_arithmetic_different_dims()
         IndexingContext.current().finalize()
+        infer_types(graph)
         set_node_indices(graph, constraints)
         expand_graph(graph, constraints)
         set_post_expansion_indices(graph, constraints)
@@ -896,6 +905,7 @@ def test_chained_gemm_32x32x8():
     ):
         graph = chained_gemm_32x32x8()
         IndexingContext.current().finalize()
+        infer_types(graph)
         set_node_indices(graph, constraints)
         expand_graph(graph, constraints)
         set_post_expansion_indices(graph, constraints)
diff --git a/lit_tests/kernel/wave/index_sequence_analysis.py b/lit_tests/kernel/wave/index_sequence_analysis.py
index 36594f22..812edd6f 100644
--- a/lit_tests/kernel/wave/index_sequence_analysis.py
+++ b/lit_tests/kernel/wave/index_sequence_analysis.py
@@ -8,6 +8,7 @@
 from iree.turbine.kernel.wave.promotion import promote_placeholders
 from iree.turbine.kernel.wave.hoisting import hoist_allocs
 from iree.turbine.kernel.wave.expansion import expand_graph
+from iree.turbine.kernel.wave.type_inference import infer_types
 from iree.turbine.kernel.lang.global_symbols import *
 from iree.turbine.kernel._support.tracing import CapturedTrace
 from iree.turbine.kernel._support.indexing import IndexingContext
@@ -84,6 +85,7 @@ def test_gemm():
     ):
         trace: CapturedTrace = gemm()
         IndexingContext.current().finalize()
+        infer_types(trace)
         promote_placeholders(trace, constraints)
         hoist_allocs(trace)
         set_node_indices(trace, constraints)
diff --git a/lit_tests/kernel/wave/minimize_global_loads.py b/lit_tests/kernel/wave/minimize_global_loads.py
index a6a61a17..f74a8764 100644
--- a/lit_tests/kernel/wave/minimize_global_loads.py
+++ b/lit_tests/kernel/wave/minimize_global_loads.py
@@ -10,6 +10,7 @@
 from iree.turbine.kernel.wave.hoisting import hoist_allocs
 from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers
 from iree.turbine.kernel.wave.expansion import expand_graph
+from iree.turbine.kernel.wave.type_inference import infer_types
 from iree.turbine.kernel.lang.global_symbols import *
 from iree.turbine.kernel._support.tracing import CapturedTrace
 from iree.turbine.kernel._support.indexing import IndexingContext
@@ -87,6 +88,7 @@ def test_gemm():
         trace: CapturedTrace = gemm()
         visualize = False
         IndexingContext.current().finalize()
+        infer_types(trace)
         promote_placeholders(trace, constraints)
         hoist_allocs(trace)
         set_node_indices(trace, constraints)
diff --git a/lit_tests/kernel/wave/promotion.py b/lit_tests/kernel/wave/promotion.py
index c3836f4f..f1f348a7 100644
--- a/lit_tests/kernel/wave/promotion.py
+++ b/lit_tests/kernel/wave/promotion.py
@@ -7,6 +7,7 @@
 import iree.turbine.kernel.wave as tkw
 from iree.turbine.kernel.wave.promotion import promote_node, promote_placeholders
 from iree.turbine.kernel.wave.hoisting import hoist_allocs
+from iree.turbine.kernel.wave.type_inference import infer_types
 from iree.turbine.kernel.lang.global_symbols import *
 from iree.turbine.kernel._support.tracing import CapturedTrace
 from iree.turbine.kernel._support.indexing import IndexingContext
@@ -67,6 +68,7 @@ def test_read_write_equal_sizes():
     graph: fx.Graph = trace.get_root_graph()
     read_node = get_read_nodes(graph)[0]
     IndexingContext.current().finalize()
+    infer_types(trace)
     promote_node(read_node, SHARED_ADDRESS_SPACE, constraints)
     print_trace(trace, False)
     # CHECK: %a
@@ -116,6 +118,7 @@ def test_read_write_equal_sizes_different_address_spaces():
     ):
         trace: CapturedTrace = read_write_same_size_different_address_spaces()
         IndexingContext.current().finalize()
+        infer_types(trace)
         promote_placeholders(trace, constraints)
         print_trace(trace, False)
     # CHECK: %a
@@ -170,10 +173,11 @@ def test_gemm():
     trace: CapturedTrace = gemm()
     graph: fx.Graph = trace.get_subgraph("region_0")
     read_nodes = get_read_nodes(graph)
+    IndexingContext.current().finalize()
+    infer_types(trace)
     for read_node in read_nodes:
         promote_node(read_node, SHARED_ADDRESS_SPACE, constraints)
     hoist_allocs(trace)
-    IndexingContext.current().finalize()
     print_trace(trace, False)
     # Root graph:
     # CHECK: %a
diff --git a/lit_tests/kernel/wave/scheduling.py b/lit_tests/kernel/wave/scheduling.py
index 87936809..00810403 100644
--- a/lit_tests/kernel/wave/scheduling.py
+++ b/lit_tests/kernel/wave/scheduling.py
@@ -8,6 +8,7 @@
 from iree.turbine.kernel.wave.promotion import promote_placeholders
 from iree.turbine.kernel.wave.hoisting import hoist_allocs
 from iree.turbine.kernel.wave.expansion import expand_graph
+from iree.turbine.kernel.wave.type_inference import infer_types
 from iree.turbine.kernel.lang.global_symbols import *
 from iree.turbine.kernel._support.tracing import CapturedTrace
 from iree.turbine.kernel._support.indexing import IndexingContext
@@ -96,6 +97,7 @@ def test_gemm_pipelined():
     ):
         trace: CapturedTrace = gemm_pipelined()
         IndexingContext.current().finalize()
+        infer_types(trace)
         promote_placeholders(trace, constraints)
         hoist_allocs(trace)
         set_node_indices(trace, constraints)
diff --git a/tests/kernel/wave/scheduling_test.py b/tests/kernel/wave/scheduling_test.py
index 269d7378..b5fdd417 100644
--- a/tests/kernel/wave/scheduling_test.py
+++ b/tests/kernel/wave/scheduling_test.py
@@ -29,6 +29,7 @@
 from iree.turbine.kernel.wave.promotion import promote_placeholders
 from iree.turbine.kernel.wave.hoisting import hoist_allocs
 from iree.turbine.kernel.wave.expansion import expand_graph
+from iree.turbine.kernel.wave.type_inference import infer_types
 from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads
 from iree.turbine.kernel.wave.scheduling.schedule import schedule_graph
 from iree.turbine.kernel.ops.wave_ops import get_custom
@@ -280,6 +281,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
         with tk.gen.TestLaunchContext(hyperparams, canonicalize=True, schedule=True):
             trace: CapturedTrace = gemm()
             IndexingContext.current().finalize()
+            infer_types(trace)
             promote_placeholders(trace, constraints)
             hoist_allocs(trace)
             set_node_indices(trace, constraints)
diff --git a/tests/kernel/wave/type_inference_test.py b/tests/kernel/wave/type_inference_test.py
new file mode 100644
index 00000000..6ce7efa2
--- /dev/null
+++ b/tests/kernel/wave/type_inference_test.py
@@ -0,0 +1,199 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import unittest
+import logging
+import iree.turbine.kernel as tk
+import iree.turbine.kernel.lang as tkl
+import iree.turbine.kernel.wave as tkw
+from iree.turbine.kernel.lang.global_symbols import *
+from iree.turbine.kernel._support.tracing import CapturedTrace
+from iree.turbine.kernel._support.indexing import IndexingContext
+from iree.turbine.kernel.lang.global_symbols import *
+from iree.turbine.kernel.wave.utils import (
+    get_mfma_load_elems_per_thread,
+    get_mfma_store_elems_per_thread,
+)
+from iree.turbine.kernel.wave.constraints import MMAType
+from iree.turbine.kernel.wave.type_inference import infer_types
+from iree.turbine.kernel.ops.wave_ops import get_custom
+
+
+class TypeInferenceTest(unittest.TestCase):
+    def testAttentionInference(self):
+        shape = (8, 128, 128, 64, 256)
+        # Input sizes
+        B = tkl.sym.B
+        M = tkl.sym.M
+        N = tkl.sym.N
+        K1 = tkl.sym.K1
+        K2 = tkl.sym.K2
+        # Workgroup tile sizes
+        BLOCK_B = tkl.sym.BLOCK_B
+        BLOCK_M = tkl.sym.BLOCK_M
+        BLOCK_N = tkl.sym.BLOCK_N
+        BLOCK_K2 = tkl.sym.BLOCK_K2
+        # Address space (for GPU, shared(1) or global(0))
+        ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
+        # Other hyperparameters
+        LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
+        STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD
+
+        # Expose user-constraints
+        constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+        constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+        constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)]
+        constraints += [tkw.TilingConstraint(K2, BLOCK_K2)]
+        constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
+        constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]
+
+        mfma_variant = MMAType.F32_16x16x16_F16
+        if mfma_variant == MMAType.F32_16x16x16_F16:
+            Mvec = 16
+            Nvec = 16
+        if mfma_variant == MMAType.F32_32x32x8_F16:
+            Mvec = 32
+            Nvec = 32
+
+        constraints += [
+            tkw.HardwareConstraint(
+                threads_per_wave=64,
+                waves_per_block=(2, 2, 1),
+                mma_type=mfma_variant,
+                vector_shapes={B: 0, M: Mvec, N: Nvec},
+            )
+        ]
+
+        i = tkw.IndexMapping.iterator(0)
+        j = tkw.IndexMapping.iterator(1)
+        k = tkw.IndexMapping.iterator(2)
+        mapping = tkw.IndexMapping(
+            num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j}
+        )
+
+        @tkw.wave_trace_only(constraints)
+        def base_attention(
+            q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16],
+            k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16],
+            v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16],
+            c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
+        ):
+            c_reg = tkl.Register[B, N, M, tkl.f32](0.0)
+            init_sum = tkl.Register[B, M, tkl.f32](0.0)
+            init_max = tkl.Register[B, M, tkl.f32](-1e6)
+
+            # This microkernel encodes the fact that if the reduction
+            # dimension were tiled, then we would need to materialize a loop.
+            @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg])
+            def repeat(
+                partial_max: tkl.Register[B, M, tkl.f32],
+                partial_sum: tkl.Register[B, M, tkl.f32],
+                acc: tkl.Register[B, N, M, tkl.f32],
+            ) -> (
+                tkl.Register[B, M, tkl.f32],
+                tkl.Register[B, M, tkl.f32],
+                tkl.Register[B, N, M, tkl.f32],
+            ):
+                imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0)
+                q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+                # k_reg: tkl.Register[B, K2, K1, tkl.f16]
+                k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+                # inner_acc: tkl.Register[B, K2, M, tkl.f32]
+                inner_acc = tkw.mma(k_reg, q_reg, imm_reg)
+                x_j = tkw.permute(inner_acc, target_shape=[B, M, K2])
+                m_j = tkw.max(x_j, partial_max, dim=K2)
+                e_delta_max = tkw.exp2(partial_max - m_j)
+                e_delta = tkw.exp2(x_j - m_j)
+                e_init = partial_sum * e_delta_max
+                d_j = tkw.sum(e_delta, e_init, dim=K2)
+                imm_f16 = tkw.cast(e_delta, tkl.f16)
+                v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+                new_acc = acc * e_delta_max
+                acc = tkw.mma(v_reg, imm_f16, new_acc)
+                return m_j, d_j, acc
+
+            # repeat represents the results of the loop
+            res_max, res_sum, res_mm = repeat
+            res = res_mm / res_sum
+            tkw.write(
+                res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD
+            )
+
+        hyperparams = {
+            ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
+            LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant),
+            STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant),
+            BLOCK_B: 1,
+            BLOCK_M: 64,
+            BLOCK_N: 64,
+            BLOCK_K2: 32,
+            B: shape[0],
+            M: shape[1],
+            N: shape[2],
+            K1: shape[3],
+            K2: shape[4],
+            READ_SHARED_DELAY: 1,
+            WRITE_SHARED_DELAY: 1,
+            READ_GLOBAL_DELAY: 2,
+            WRITE_GLOBAL_DELAY: 2,
+            MMA_DELAY: 1,
+            SHARED_MEMORY_UNITS: 4,
+            GLOBAL_MEMORY_UNITS: 4,
+            MMA_UNITS: 4,
+        }
+        with tk.gen.TestLaunchContext(
+            hyperparams,
+            canonicalize=True,
+            run=False,
+            run_bench=False,
+            schedule=False,
+            use_scheduling_barriers=False,
+        ):
+            trace: CapturedTrace = base_attention()
+            IndexingContext.current().finalize()
+            infer_types(trace)
+            expected_type = {
+                "partial_sum": "Register[B, M].of(f32)",
+                "partial_max": "Register[B, M].of(f32)",
+                "acc": "Register[B, N, M].of(f32)",
+                "q": "Memory[B, M, K1].of(f16)",
+                "read": "Register[B, M, K1].of(f16)",
+                "k": "Memory[B, K2, K1].of(f16)",
+                "read_1": "Register[B, K2, K1].of(f16)",
+                "mma": "Register[B, K2, M].of(f32)",
+                "permute": "Register[B, M, K2].of(f32)",
+                "max_1": "Register[B, M].of(f32)",
+                "sub": "Register[B, M].of(f32)",
+                "exp2": "Register[B, M].of(f32)",
+                "sub_1": "Register[B, M, K2].of(f32)",
+                "exp2_1": "Register[B, M, K2].of(f32)",
+                "mul": "Register[B, M].of(f32)",
+                "sum_1": "Register[B, M].of(f32)",
+                "cast": "Register[B, M, K2].of(f16)",
+                "v": "Memory[B, N, K2].of(f16)",
+                "read_2": "Register[B, N, K2].of(f16)",
+                "mul_1": "Register[B, N, M].of(f32)",
+                "mma_1": "Register[B, N, M].of(f32)",
+                "c": "Memory[B, M, N].of(f32)",
+                "register_1": "Register[B, M].of(f32)",
+                "register_2": "Register[B, M].of(f32)",
+                "reduction": "[Register[B, M].of(f32), Register[B, M].of(f32), Register[B, N, M].of(f32)]",
+                "getitem": "Register[B, M].of(f32)",
+                "getitem_1": "Register[B, M].of(f32)",
+                "getitem_2": "Register[B, N, M].of(f32)",
+                "truediv": "Register[B, N, M].of(f32)",
+                "write": "Memory[B, N, M].of(f32)",
+            }
+            for subgraph in trace.region_graph.subgraphs.values():
+                for node in subgraph.nodes:
+                    custom = get_custom(node)
+                    if custom.fx_node.name in expected_type:
+                        assert str(custom.type) == expected_type[custom.fx_node.name]
+
+
+if __name__ == "__main__":
"__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/tests/kernel/wave/visualization_test.py b/tests/kernel/wave/visualization_test.py index a7e84b0d..04a1959e 100644 --- a/tests/kernel/wave/visualization_test.py +++ b/tests/kernel/wave/visualization_test.py @@ -13,6 +13,7 @@ import iree.turbine.kernel.lang as tkl import iree.turbine.kernel.wave as tkw from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.wave.type_inference import infer_types from iree.turbine.kernel._support.tracing import CapturedTrace from iree.turbine.kernel._support.indexing import IndexingContext from iree.turbine.kernel.ops.wave_ops import get_custom @@ -93,6 +94,7 @@ def test_gemm(): ): graph = gemm() IndexingContext.current().finalize() + infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) set_post_expansion_indices(graph, constraints)