Add type inference
This PR adds a type inference pass to wave. Previously,
types were inferred by looking up the types of neighboring
nodes on every access, resulting in inefficient type
inference.

Instead, we now introduce a pass that infers the types of
all operators in the graph and stores the inferred type on
the node.
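
The shape of the change, as a minimal sketch (adapted from the Read
op in the diff below; simplified, not the full wave API):

    # Before: `type` was a property, so the operand types were
    # looked up from neighboring nodes on every access.
    @property
    def type(self) -> "Register":
        return Register[*self.indexing_dims, self.memory_type.dtype]

    # After: a dedicated pass calls infer_type() once per node and
    # the result is stored on the underlying fx.Node.
    def infer_type(self):
        self.type = Register[*self.indexing_dims, self.memory_type.dtype]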

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
harsh-nod committed Nov 5, 2024
1 parent ee62366 commit 286c235
Showing 13 changed files with 306 additions and 43 deletions.
94 changes: 53 additions & 41 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -338,7 +338,7 @@ def custom_string(self, value_map: dict[str, str]) -> str:
vars_str = ", ".join(vars_list)
return f"{self.tkw_op_name}({vars_str})"

def add_to_graph(self, region_graph: RegionGraph) -> fx.Node:
def add_to_graph(self, region_graph: RegionGraph, type: Any = None) -> fx.Node:
arg_list = tuple([value for _, value in vars(self).items()])
self.graph = region_graph
self.fx_node = region_graph.create_node(
@@ -350,6 +350,10 @@ def add_to_graph(self, region_graph: RegionGraph) -> fx.Node:
self.fx_node.tkw_op = self.__class__
self.fx_node.tkw_op_name = self.tkw_op_name
self.fx_node.index = None
if type is None:
get_custom(self.fx_node).infer_type()
else:
self.fx_node.type = type
return self.fx_node

def _add_proxy_to_graph(self, region_graph: RegionGraph):
@@ -556,6 +560,23 @@ def vector_shapes(self) -> dict[IndexSymbol, int]:
def vector_shapes(self, value: dict[IndexSymbol, int]):
self.fx_node.vector_shapes = value

@property
def type(self) -> Any:
if hasattr(self.fx_node, "type"):
return self.fx_node.type
return None

@type.setter
def type(self, value: Any):
self.fx_node.type = value

def infer_type(self):
"""
Infer the type of this operator using the types
of its arguments.
"""
pass

def align_index(self, constraints: list["Constraint"]) -> None:
"""
Align index to WG/Tile sizes.
@@ -602,21 +623,21 @@ def indexing_dims(self) -> list[IndexSymbol]:
def py_operator(self) -> str:
return self.tkw_op_name

@property
def type(self) -> Memory:
def infer_type(self):
lhs_type = get_custom(self.lhs).type
rhs_type = get_custom(self.rhs).type
has_same_type = has_same_custom_type(lhs_type, rhs_type)
if has_same_type:
return lhs_type
self.type = lhs_type
return
lhs_dim_set = set(lhs_type.symbolic_shape)
rhs_dim_set = set(rhs_type.symbolic_shape)
if lhs_dim_set.isdisjoint(rhs_dim_set):
raise ValueError(
"BinaryPyOp requires lhs and rhs shape to be at least broadcastable."
)
broadcasted_type = lhs_type if lhs_dim_set > rhs_dim_set else rhs_type
return broadcasted_type
self.type = broadcasted_type


@define_interface_op("exp2")
@@ -637,10 +658,9 @@ def indexing_dims(self) -> list[IndexSymbol]:
def py_operator(self) -> str:
return self.tkw_op_name

@property
def type(self) -> Memory:
def infer_type(self):
src_type = get_custom(self.arg).type
return src_type
self.type = src_type


@final
@@ -868,9 +888,8 @@ def rhs_type(self) -> Memory:
def acc_type(self) -> Memory:
return get_custom(self.acc).type

@property
def type(self) -> Memory:
return self.acc_type
def infer_type(self):
self.type = self.acc_type

def operand_index(
self, operand_map: dict[IndexSymbol, int], shape: list[IndexExpr]
@@ -925,6 +944,7 @@ def reduction_dim(self, value: IndexSymbol):
@define_op("read")
@dataclass
class Read(CustomOp):

memory: fx.Proxy
elements_per_thread: Optional[Any] = None
mapping: Optional[IndexMapping] = None
@@ -937,10 +957,9 @@ def indexing_dims(self) -> list[IndexSymbol]:
# TODO: This could contain ints.
return list(self.memory_type.symbolic_shape)

@property
def type(self) -> "Register":
def infer_type(self):
dtype = self.memory_type.dtype
return Register[*self.indexing_dims, dtype]
self.type = Register[*self.indexing_dims, dtype]

@property
def memory_type(self) -> "Memory":
@@ -1052,12 +1071,11 @@ def captured_vars(self, graph: fx.Graph) -> list[fx.Node]:
captured_vars.append(nested_node)
return captured_vars

@property
def type(self) -> Memory | Register | list[Memory | Register]:
def infer_type(self):
res_types = [get_custom(x).type for x in self.init_args]
if len(res_types) == 1:
res_types = res_types[0]
return res_types
self.type = res_types

def outputs(self, graph: fx.Graph) -> list[fx.Node]:
for node in graph.nodes:
@@ -1110,11 +1128,12 @@ def indexing_dims(self) -> list[IndexSymbol]:
if self.mapping is not None:
return list(self.mapping.input_shape)
# TODO: This could contain ints.
return list(self.type.symbolic_shape)
return list(self.memory_type.symbolic_shape)

@property
def type(self) -> "Memory":
return get_custom(self.memory).type
def infer_type(self):
address_space = self.memory_type.address_space
dtype = self.memory_type.dtype
self.type = Memory[*self.indexing_dims, address_space, dtype]

@property
def memory_type(self) -> "Memory":
@@ -1144,13 +1163,12 @@ class GetResult(CustomOp):
value: fx.Node
res_idx: int

@property
def type(self) -> "Memory":
def infer_type(self):
src_type = get_custom(self.value).type
if isinstance(src_type, list):
return src_type[self.res_idx]
self.type = src_type[self.res_idx]
else:
return src_type
self.type = src_type

@property
def indexing_dims(self) -> list[IndexExpr]:
@@ -1200,14 +1218,14 @@ class Extract(CustomOp):
register_: fx.Proxy
offset: IndexExpr | int

@property
def type(self) -> "Register":
def infer_type(self):
# Intuition here is we are trying to extract an element
# from fastest dim => we reduce the fastest dim.
src_type = get_custom(self.register_).type
# Return itself if just 0-D/1-D symbolic.
if len(src_type.symbolic_shape) <= 1:
return src_type
self.type = src_type
return

# Typically fastest dim is the last dimension,
# If non-unit dim exists => non-unit dim is fastest dim.
@@ -1220,7 +1238,7 @@ def type(self) -> "Register":
dim_to_remove = dst_shape[-1] if not non_unit_dim else non_unit_dim[0]
dst_shape.remove(dim_to_remove)
dst_type = Register[*dst_shape, src_type.dtype]
return dst_type
self.type = dst_type


@define_op("extract_slice")
@@ -1297,12 +1315,8 @@ def indexing_dims(self) -> list[IndexSymbol]:
dst_indexing = [dim for dim in src_indexing if dim != self.dim]
return dst_indexing

@property
def type(self) -> Memory:
def infer_type(self):
if isinstance(self.arg, Sequence):
# Local import to break circular dep.
from ..wave.utils import all_equal

src_types = [get_custom(arg).type for arg in self.arg]
ref_shape = src_types[0].symbolic_shape
ref_dtype = src_types[0].dtype
@@ -1318,7 +1332,7 @@ def type(self) -> Memory:
src_type = get_custom(self.arg).type
reduced_dims = [dims for dims in src_type.symbolic_shape if dims != self.dim]
dst_type = Register[*reduced_dims, src_type.dtype]
return dst_type
self.type = dst_type

@property
def num_reduction_dims(self) -> int:
@@ -1376,10 +1390,9 @@ class CastOp(CustomOp, ABC):
def indexing_dims(self) -> list[IndexSymbol]:
return get_custom(self.arg).indexing_dims

@property
def type(self) -> Memory:
def infer_type(self):
src_shape = get_custom(self.arg).type.symbolic_shape
return Register[*src_shape, self.dtype]
self.type = Register[*src_shape, self.dtype]


@define_op("permute")
@@ -1397,13 +1410,12 @@ class Permute(CustomOp, ABC):
def indexing_dims(self) -> list[IndexExpr]:
return self.target_shape

@property
def type(self) -> Register:
def infer_type(self):
src_type = get_custom(self.arg).type
assert set(src_type.symbolic_shape) == set(
self.target_shape
), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}"
return Register[*self.target_shape, src_type.dtype]
self.type = Register[*self.target_shape, src_type.dtype]


def _to_sequence(input: Any | Sequence[Any]) -> Sequence[Any]:
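To make the BinaryPyOp rule above concrete, a worked example with
hypothetical shapes (not from the PR):

    # lhs: Register[M, N, f16], rhs: Register[N, f16]
    #   -> the dim sets {M, N} and {N} overlap, and {M, N} is the
    #      strict superset, so infer_type() stores Register[M, N, f16].
    # lhs: Register[M, f16], rhs: Register[N, f16]
    #   -> the dim sets {M} and {N} are disjoint, so infer_type()
    #      raises ValueError (the shapes are not broadcastable).
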
2 changes: 1 addition & 1 deletion iree/turbine/kernel/wave/expansion.py
@@ -336,7 +336,7 @@ def _expand_reduction(
# Add GetResult nodes for the corresponding dimensions
reduction.graph.inserting_after(reduction.fx_node)
new_node = GetResult(reduction.fx_node, len(new_output_args))
new_node.add_to_graph(reduction.graph)
new_node.add_to_graph(reduction.graph, arg.type)
new_node.fx_node.name = get_expanded_name(new_node, dims)
context[
(reduction, get_indexed_dims(dims, expand_dims), arg_idx)
21 changes: 21 additions & 0 deletions iree/turbine/kernel/wave/type_inference.py
@@ -0,0 +1,21 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ..ops.wave_ops import *
from .._support.tracing import CapturedTrace
import torch.fx as fx
from ...support.logging import get_logger

logger = get_logger("turbine.wave.type_inference")


def infer_types(trace: CapturedTrace | fx.Graph):
# Infer and set the types for all nodes in the graph.
for subgraph in trace.region_graph.subgraphs.values():
for node in subgraph.nodes:
custom = get_custom(node)
custom.infer_type()
logger.debug(f"Setting type for {custom.fx_node} = {custom.type}")
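
A usage sketch, mirroring wave.py and the lit tests below: the pass
runs once per trace after the indexing context is finalized, so that
downstream passes can read node types directly instead of recomputing
them:

    trace: CapturedTrace = gemm()        # traced kernel, as in the tests
    IndexingContext.current().finalize()
    infer_types(trace)                   # every node now stores its type
    set_node_indices(trace, constraints) # later passes read node.type
    expand_graph(trace, constraints)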
4 changes: 4 additions & 0 deletions iree/turbine/kernel/wave/wave.py
@@ -49,6 +49,7 @@
from .thread_shape_analysis import determine_thread_shapes
from .scheduling.schedule import schedule_graph
from .._support.indexing import IndexingContext, IndexExpr
from .type_inference import infer_types
import iree.turbine.kernel.lang as tkl
from .._support.tracing import (
CapturedTrace,
@@ -224,6 +225,9 @@ def _trace_and_get_kernel_signature(
# Initialize Vector shapes
self.hardware_constraints[0].subs_vector_shapes(idxc.subs)

# Do type inference.
infer_types(graph)

# Promote the placeholders to the appropriate address space.
promote_placeholders(graph, self.constraints)
hoist_allocs(graph)
3 changes: 3 additions & 0 deletions lit_tests/kernel/wave/barriers.py
@@ -14,6 +14,7 @@
from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers
from iree.turbine.kernel.wave.hoisting import hoist_allocs
from iree.turbine.kernel.wave.expansion import expand_graph
from iree.turbine.kernel.wave.type_inference import infer_types
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel._support.tracing import CapturedTrace
from iree.turbine.kernel._support.indexing import IndexingContext
@@ -86,6 +87,7 @@ def test_read_write_equal_sizes():
graph: fx.Graph = trace.get_root_graph()
read_node = get_read_nodes(graph)[0]
IndexingContext.current().finalize()
infer_types(trace)
promote_node(read_node, SHARED_ADDRESS_SPACE, constraints)
set_node_indices(trace, constraints)
expand_graph(trace, constraints)
@@ -171,6 +173,7 @@ def test_gemm():
trace: CapturedTrace = gemm()
graph: fx.Graph = trace.get_subgraph("region_0")
IndexingContext.current().finalize()
infer_types(trace)
read_nodes = get_read_nodes(graph)
for read_node in read_nodes:
promote_node(read_node, SHARED_ADDRESS_SPACE, constraints)
10 changes: 10 additions & 0 deletions lit_tests/kernel/wave/expansion.py
@@ -6,6 +6,7 @@
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.wave.expansion import expand_graph
from iree.turbine.kernel.wave.type_inference import infer_types
from iree.turbine.kernel.wave.index_sequence_analysis import (
set_node_indices,
set_post_expansion_indices,
@@ -69,6 +70,7 @@ def test_read_write_equal_sizes():
):
graph = read_write_same_size()
IndexingContext.current().finalize()
infer_types(graph)
set_node_indices(graph, constraints)
expand_graph(graph, constraints)
set_post_expansion_indices(graph, constraints)
@@ -150,6 +152,7 @@ def test_read_write():
):
graph = read_write_different_dims()
IndexingContext.current().finalize()
infer_types(graph)
set_node_indices(graph, constraints)
expand_graph(graph, constraints)
set_post_expansion_indices(graph, constraints)
@@ -227,6 +230,7 @@ def test_gemm():
):
graph = gemm()
IndexingContext.current().finalize()
infer_types(graph)
set_node_indices(graph, constraints)
expand_graph(graph, constraints)
set_post_expansion_indices(graph, constraints)
@@ -413,6 +417,7 @@ def test_batched_gemm():
):
graph = batched_gemm()
IndexingContext.current().finalize()
infer_types(graph)
set_node_indices(graph, constraints)
expand_graph(graph, constraints)
set_post_expansion_indices(graph, constraints)
@@ -591,6 +596,7 @@ def test_gemm_non_direct_acc():
):
graph = gemm_non_direct_acc()
IndexingContext.current().finalize()
infer_types(graph)
set_node_indices(graph, constraints)
expand_graph(graph, constraints)
set_post_expansion_indices(graph, constraints)
@@ -657,6 +663,7 @@ def test_tiled_max():
):
graph = tiled_max()
IndexingContext.current().finalize()
infer_types(graph)
set_node_indices(graph, constraints)
expand_graph(graph, constraints)
set_post_expansion_indices(graph, constraints)
@@ -688,6 +695,7 @@ def test_gemm_reduction_expansion_only():
):
graph = gemm()
IndexingContext.current().finalize()
infer_types(graph)
set_node_indices(graph, constraints)
expand_graph(graph, constraints)
set_post_expansion_indices(graph, constraints)
@@ -791,6 +799,7 @@ def py_arithmetic_different_dims():
):
graph = py_arithmetic_different_dims()
IndexingContext.current().finalize()
infer_types(graph)
set_node_indices(graph, constraints)
expand_graph(graph, constraints)
set_post_expansion_indices(graph, constraints)
@@ -896,6 +905,7 @@ def test_chained_gemm_32x32x8():
):
graph = chained_gemm_32x32x8()
IndexingContext.current().finalize()
infer_types(graph)
set_node_indices(graph, constraints)
expand_graph(graph, constraints)
set_post_expansion_indices(graph, constraints)
