Add support for scheduling attention
This PR adds support for scheduling attention operators.
harsh-nod committed Nov 5, 2024
1 parent 8d02bc0 commit 3bbe6f8
Showing 11 changed files with 330 additions and 36 deletions.
4 changes: 4 additions & 0 deletions iree/turbine/kernel/lang/global_symbols.py
@@ -27,6 +27,10 @@
READ_GLOBAL_DELAY = index_symbol("$READ_GLOBAL_DELAY")
WRITE_GLOBAL_DELAY = index_symbol("$WRITE_GLOBAL_DELAY")
MMA_DELAY = index_symbol("$MMA_DELAY")
VALU_DELAY = index_symbol("$VALU_DELAY")
SHUFFLE_DELAY = index_symbol("$SHUFFLE_DELAY")
SHARED_MEMORY_UNITS = index_symbol("$SHARED_MEMORY_UNITS")
GLOBAL_MEMORY_UNITS = index_symbol("$GLOBAL_MEMORY_UNITS")
MMA_UNITS = index_symbol("$MMA_UNITS")
VALU_UNITS = index_symbol("$VALU_UNITS")
SHUFFLE_UNITS = index_symbol("$SHUFFLE_UNITS")
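For context, these index symbols are resolved through the substitutions a kernel is launched with. A minimal sketch of how a schedule configuration might pin the four new symbols (the values are illustrative, not taken from this diff):

from iree.turbine.kernel.lang.global_symbols import (
    VALU_DELAY,
    SHUFFLE_DELAY,
    VALU_UNITS,
    SHUFFLE_UNITS,
)

# Illustrative values: cycles taken by each new op class, and how many
# functional units of each kind the modulo scheduler may occupy per cycle.
scheduling_params = {
    VALU_DELAY: 2,
    SHUFFLE_DELAY: 2,
    VALU_UNITS: 2,
    SHUFFLE_UNITS: 2,
}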
83 changes: 80 additions & 3 deletions iree/turbine/kernel/wave/scheduling/graph_utils.py
@@ -12,6 +12,8 @@
from dataclasses import dataclass
import sympy
import math
from functools import partial
import multiprocessing as mp

T = index_symbol("$INITIATION_INTERVAL")

@@ -157,7 +159,30 @@ def find_cycles_in_scc(scc: dict[fx.Node, list[fx.Node]]) -> list[list[fx.Node]]
return circuits


def all_pairs_longest_paths(
def all_pairs_longest_paths_helper(
graph: fx.Graph, u: fx.Node, dist: dict[tuple[fx.Node, fx.Node], IndexExpr], i: int
):
print(f"Got u = {u}")
v = list(graph.nodes)[i]
for w in graph.nodes:
dist[(v, w)] = sympy.Max(dist[(v, w)], dist[(v, u)] + dist[(u, w)])
return v, dist


def all_pairs_longest_path_parallel(N: int, D: np.ndarray, k: int, i: int):
    """
    Relax row i of the distance matrix against intermediate node k.
    Called in parallel, once per value of i, for each fixed k.
    """
for j in range(N):
D[i, j] = np.maximum(D[i, j], D[i, k] + D[k, j])
return i, D[i]


def all_pairs_longest_paths_symbolic(
graph: fx.Graph,
edges: list[Edge],
) -> dict[tuple[fx.Node, fx.Node], IndexExpr]:
@@ -181,6 +206,51 @@ def all_pairs_longest_paths(
return D


def all_pairs_longest_paths(
graph: fx.Graph,
edges: list[Edge],
T: int,
) -> dict[tuple[fx.Node, fx.Node], IndexExpr]:
"""
For each node in the graph, compute the longest path to all other nodes.
Uses the Floyd-Warshall algorithm and assumes that the cycles don't
have positive weights. This function computes the distances in parallel
by parallelizing across the start nodes.
"""
N = len(graph.nodes)
    D = np.full((N, N), -np.inf, dtype=np.float32)

all_nodes = list(graph.nodes)
for edge in edges:
i = all_nodes.index(edge._from)
j = all_nodes.index(edge._to)
D[i, j] = edge.weight.delay - edge.weight.iteration_difference * T

# Parallel implementation
pool = mp.get_context("fork").Pool(processes=mp.cpu_count())
for k in range(N):
func = partial(all_pairs_longest_path_parallel, N, D, k)
results = pool.map(func, range(N))
for result in results:
D[result[0]] = result[1]
pool.close()
pool.join()

# Convert from index to node based representation.
G: dict[tuple[fx.Node, fx.Node], int] = {}
for i, from_node in enumerate(graph.nodes):
for j, to_node in enumerate(graph.nodes):
if np.isinf(D[i, j]) or i == j:
continue
G[(from_node, to_node)] = int(D[i, j])

return G
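
For intuition, this is the max-plus (longest-path) variant of Floyd-Warshall over edge costs delay - iteration_difference * T. A minimal serial sketch of the same recurrence on a hypothetical three-node chain, independent of the code above:

import numpy as np

# Hypothetical edge costs, already reduced to delay - iteration_difference * T.
N = 3
D = np.full((N, N), -np.inf)
D[0, 1] = 2.0  # e.g. a read -> mma edge
D[1, 2] = 4.0  # e.g. an mma -> write edge
for k in range(N):
    for i in range(N):
        for j in range(N):
            D[i, j] = max(D[i, j], D[i, k] + D[k, j])
assert D[0, 2] == 6.0  # longest path 0 -> 1 -> 2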


def evaluate_all_pairs_longest_paths(
D: dict[tuple[fx.Node, fx.Node], IndexExpr], initiation_interval: int
) -> dict[tuple[fx.Node, fx.Node], int]:
@@ -190,7 +260,8 @@ def evaluate_all_pairs_longest_paths(
"""
D_static = dict(D)
for key in D_static:
D_static[key] = D_static[key].subs(T, initiation_interval)
if isinstance(D_static[key], sympy.Expr):
D_static[key] = D_static[key].subs(T, initiation_interval)
# Remove the negative infinity values and edges to self.
for k in list(D_static.keys()):
if math.isinf(D_static[k]) or k[0] == k[1]:
@@ -244,8 +315,14 @@ def get_scheduling_weight(node: fx.Node) -> EdgeWeight:
weight = EdgeWeight(0, delay_table[Operation.MMA])
case IterArg():
weight = EdgeWeight(1, 0)
case CastOp():
case CastOp() | Extract() | Permute() | Broadcast():
weight = EdgeWeight(0, delay_table[Operation.NOOP])
case UnaryPyOp() | BinaryPyOp():
weight = EdgeWeight(0, delay_table[Operation.VALU])
case ShuffleOp():
weight = EdgeWeight(0, delay_table[Operation.SHUFFLE])
case _:
raise ValueError(f"Unsupported node type: {custom_node}")
weight.delay = subs_idxc(weight.delay)
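As a reference for the weights above: assuming the EdgeWeight dataclass orders its fields as (iteration_difference, delay), which is how the constructor calls above use it, each edge contributes delay - iteration_difference * T to the longest-path matrix, so a loop-carried IterArg edge pulls the distance down by one initiation interval. A small self-contained illustration with a hypothetical mirror of the dataclass:

from dataclasses import dataclass

@dataclass
class EdgeWeight:  # hypothetical stand-in for the dataclass in this module
    iteration_difference: int = 0
    delay: int = 0

T = 10                      # a candidate initiation interval (illustrative)
mma = EdgeWeight(0, 4)      # same-iteration edge with a 4-cycle delay
carried = EdgeWeight(1, 0)  # loop-carried IterArg edge
assert mma.delay - mma.iteration_difference * T == 4
assert carried.delay - carried.iteration_difference * T == -10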
68 changes: 56 additions & 12 deletions iree/turbine/kernel/wave/scheduling/loop_reconstruction.py
@@ -9,6 +9,8 @@
GetResult,
get_custom,
SchedulingGroupBarrier,
MMA,
NewRegister,
)
from .modulo_scheduling import ModuloScheduler
from ..utils import (
@@ -99,10 +101,11 @@ def add_nodes_by_schedule(
# Set the index for the new node by substituting the induction variable
# for the current iteration.
new_node.index = node.index
for dim in new_node.index:
new_node.index[dim] = new_node.index[dim].subs(
{induction_variable: current_induction_variables[iteration]}
)
if new_node.index:
for dim in new_node.index:
new_node.index[dim] = new_node.index[dim].subs(
{induction_variable: current_induction_variables[iteration]}
)
# Add scheduling parameters for debugging.
new_node.scheduling_parameters = node.scheduling_parameters
# Update the rotating registers and argument context for the current node (if applicable).
@@ -154,6 +157,21 @@ def push_placeholders(
arg_context.map_arg_all(node, root_node)


def add_missing_registers(graph: fx.Graph):
    """
    Registers are not materialized during scheduling, so an MMA op can be
    moved into a different graph than the register feeding its accumulator.
    For each such MMA op, create a new register in the MMA's graph and use
    it as the accumulator.
    """
for node in graph.nodes:
custom = get_custom(node)
if isinstance(custom, MMA):
acc = get_custom(custom.acc)
if acc.graph != custom.graph:
with custom.graph.inserting_before(node):
register = NewRegister(
acc.shape, acc.dtype, acc.value
).add_to_graph(custom.graph)
register.index = acc.index
custom.update_arg("acc", register)


def construct_prologue(
reduction_subgraph: fx.Graph,
reduction: Reduction,
@@ -213,6 +231,12 @@ def construct_prologue(
new_init_args.append(mapped_init_arg)
reduction.init_args = new_init_args

    # Add missing registers. Since registers are not materialized during
    # scheduling, mma ops moved outside the reduction may be left without
    # a corresponding register. The call below remedies this.
    add_missing_registers(reduction.graph)


def flatten_dict_values(
rotating_registers: dict[fx.Node, list[fx.Node]]
@@ -386,6 +410,12 @@ def construct_kernel(
"kernel.png",
)

    # Add missing registers. Since registers are not materialized during
    # scheduling, mma ops moved outside the reduction may be left without
    # a corresponding register. The call below remedies this.
    add_missing_registers(pipelined_reduction_graph)

return pipelined_reduction, pipelined_reduction_graph


@@ -422,16 +452,30 @@ def construct_epilogue(
scheduler.num_stages,
)

existing_get_results: list[GetResult] = sorted(
[x for x in pipelined_reduction.users if isinstance(x, GetResult)],
key=lambda x: x.res_idx,
)
existing_users = {x: x.users for x in existing_get_results}
existing_get_results: list[GetResult] = [
x for x in pipelined_reduction.users if isinstance(x, GetResult)
]
existing_indices = [x.res_idx for x in existing_get_results]

# Map the results from the kernel to the init args (for stages).
for iter_arg, get_result in zip(
reduction.iter_args(reduction_subgraph), existing_get_results
):
    # The number of iter args may not match the number of get results,
    # so we add get results for the missing iter args.
iter_args = reduction.iter_args(reduction_subgraph)
for i in range(len(iter_args)):
if i in existing_indices:
continue
with pipelined_reduction.graph.inserting_before(
existing_get_results[0].fx_node.next
):
result = GetResult(pipelined_reduction.fx_node, i).add_to_graph(
pipelined_reduction.graph
)
existing_get_results.append(get_custom(result))

existing_get_results = sorted(existing_get_results, key=lambda x: x.res_idx)
existing_users = {x: x.users for x in existing_get_results}

for iter_arg, get_result in zip(iter_args, existing_get_results):
arg_context.map_arg_all(iter_arg, get_result.fx_node)

with pipelined_reduction.graph.inserting_before(
5 changes: 1 addition & 4 deletions iree/turbine/kernel/wave/scheduling/modulo_scheduling.py
@@ -106,17 +106,14 @@ def schedule_graph(self) -> tuple[dict[fx.Node, int], bool]:
# Initialize initiation interval.
T0 = int(max(self.compute_resource_ii(), self.compute_recurrence_ii(sccs)))

# Compute symbolic all pairs longest path.
e_star_symbolic = all_pairs_longest_paths(self.graph, self.edges)

# Generate the schedule.
# TODO: Come up with a better heuristic on an upper bound for the initiation interval.
T_max_range = 3 * T0
success = False
for T in range(T0, T0 + T_max_range):
logger.debug(f"Trying initiation interval: {T}.")
self.RT = np.zeros((T, len(self.resources)))
self.e_star = evaluate_all_pairs_longest_paths(e_star_symbolic, T)
self.e_star = all_pairs_longest_paths(self.graph, self.edges, T)
logger.debug(f"All Pairs Longest Paths: {self.e_star}.")
self.schedule: dict[fx.Node, int] = {}
for _, scc in topological_sort(sccs).items():
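With this change the longest-path table is recomputed numerically for each candidate initiation interval instead of substituting T into one symbolic table; a few numpy Floyd-Warshall passes are typically cheaper than evaluating large sympy Max expressions. A hypothetical sketch of the resulting search-loop shape (stand-in names, not the class's exact code):

def find_schedule(graph, edges, T0, attempt):
    # Try initiation intervals from the lower bound upward, rebuilding the
    # numeric all-pairs longest-path table for each candidate T.
    for T in range(T0, T0 + 3 * T0):
        e_star = all_pairs_longest_paths(graph, edges, T)
        schedule = attempt(T, e_star)
        if schedule is not None:
            return schedule, T
    return None, None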
50 changes: 38 additions & 12 deletions iree/turbine/kernel/wave/scheduling/resources.py
@@ -15,6 +15,12 @@
get_custom,
CustomOp,
CastOp,
UnaryPyOp,
BinaryPyOp,
ShuffleOp,
Permute,
Extract,
Broadcast,
)
import torch.fx as fx
from enum import Enum
@@ -23,7 +29,13 @@

# This table contains the number of functional units available for each operation.
def get_available_resources() -> list[int]:
resources = [GLOBAL_MEMORY_UNITS, SHARED_MEMORY_UNITS, MMA_UNITS]
resources = [
GLOBAL_MEMORY_UNITS,
SHARED_MEMORY_UNITS,
MMA_UNITS,
VALU_UNITS,
SHUFFLE_UNITS,
]
return np.array([int(subs_idxc(x)) for x in resources])


@@ -37,8 +49,11 @@ class Operation(Enum):
VALU = "valu"
SALU = "salu"
NOOP = "noop"
SHUFFLE = "shuffle"


SCHEDULING_NOOPS = (IterArg, Permute, Extract, Broadcast, CastOp)

# This table contains the cycles required to execute each operation.
delay_table = {
Operation.READ_SHARED: READ_SHARED_DELAY,
@@ -47,17 +62,21 @@ class Operation(Enum):
Operation.WRITE_GLOBAL: WRITE_GLOBAL_DELAY,
Operation.MMA: MMA_DELAY,
Operation.NOOP: 0,
Operation.VALU: VALU_DELAY,
Operation.SHUFFLE: SHUFFLE_DELAY,
}

# This table contains the resource usage for each operation.
# Operations can use more than one resource for more than one cycle.
resource_reservation_table = {
Operation.READ_SHARED: np.array([[0, 1, 0]]),
Operation.WRITE_SHARED: np.array([[0, 1, 0]]),
Operation.READ_GLOBAL: np.array([[1, 0, 0]]),
Operation.WRITE_GLOBAL: np.array([[1, 0, 0]]),
Operation.MMA: np.array([[0, 0, 1]]),
Operation.NOOP: np.array([[0, 0, 0]]),
Operation.READ_SHARED: np.array([[0, 1, 0, 0, 0]]),
Operation.WRITE_SHARED: np.array([[0, 1, 0, 0, 0]]),
Operation.READ_GLOBAL: np.array([[1, 0, 0, 0, 0]]),
Operation.WRITE_GLOBAL: np.array([[1, 0, 0, 0, 0]]),
Operation.MMA: np.array([[0, 0, 1, 0, 0]]),
Operation.NOOP: np.array([[0, 0, 0, 0, 0]]),
Operation.VALU: np.array([[0, 0, 0, 1, 0]]),
Operation.SHUFFLE: np.array([[0, 0, 0, 0, 1]]),
}
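
The five columns of each reservation row line up with the order in get_available_resources(): [GLOBAL_MEMORY_UNITS, SHARED_MEMORY_UNITS, MMA_UNITS, VALU_UNITS, SHUFFLE_UNITS]. A simplified sketch of how a modulo scheduler can test a reservation table against those units (illustrative only, not the scheduler's exact code):

import numpy as np

resources = np.array([4, 4, 4, 2, 2])  # illustrative unit counts, same column order
RT = np.zeros((3, 5))                  # per-cycle usage for a candidate T = 3

def can_place(RT, cycle, rrt, resources):
    # An op fits at `cycle` if adding each row of its reservation table
    # keeps the corresponding modulo slot within the available units.
    T = RT.shape[0]
    for r, row in enumerate(rrt):
        if np.any(RT[(cycle + r) % T] + row > resources):
            return False
    return True

assert can_place(RT, 0, np.array([[0, 0, 1, 0, 0]]), resources)  # an MMA fits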


@@ -76,10 +95,12 @@ def get_custom_operation_type(custom: CustomOp) -> Operation:
)
elif isinstance(custom, MMA):
return Operation.MMA
elif isinstance(custom, IterArg):
return Operation.NOOP
elif isinstance(custom, Output):
elif isinstance(custom, SCHEDULING_NOOPS + (Output,)):
return Operation.NOOP
elif isinstance(custom, (UnaryPyOp, BinaryPyOp)):
return Operation.VALU
elif isinstance(custom, ShuffleOp):
return Operation.SHUFFLE
else:
return None

@@ -106,8 +127,13 @@ def annotate_resource_usage(
)
elif isinstance(custom, MMA):
custom.rrt = resource_reservation_table[Operation.MMA]
elif isinstance(custom, (IterArg, CastOp)):
iter_args.append(node)
elif isinstance(custom, ShuffleOp):
custom.rrt = resource_reservation_table[Operation.SHUFFLE]
elif isinstance(custom, (UnaryPyOp, BinaryPyOp)):
custom.rrt = resource_reservation_table[Operation.VALU]
elif isinstance(custom, SCHEDULING_NOOPS):
if isinstance(custom, IterArg):
iter_args.append(node)
custom.rrt = resource_reservation_table[Operation.NOOP]
elif isinstance(custom, Output):
output = node
