Skip to content

Commit

Permalink
Add support for scheduling attention
Browse files Browse the repository at this point in the history
This PR adds support for scheduling in
attention operators.
  • Loading branch information
harsh-nod committed Nov 7, 2024
1 parent db1ec57 commit 176cffa
Show file tree
Hide file tree
Showing 12 changed files with 368 additions and 42 deletions.
4 changes: 4 additions & 0 deletions iree/turbine/kernel/lang/global_symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
READ_GLOBAL_DELAY = index_symbol("$READ_GLOBAL_DELAY")
WRITE_GLOBAL_DELAY = index_symbol("$WRITE_GLOBAL_DELAY")
MMA_DELAY = index_symbol("$MMA_DELAY")
VALU_DELAY = index_symbol("$VALU_DELAY")
SHUFFLE_DELAY = index_symbol("$SHUFFLE_DELAY")
SHARED_MEMORY_UNITS = index_symbol("$SHARED_MEMORY_UNITS")
GLOBAL_MEMORY_UNITS = index_symbol("$GLOBAL_MEMORY_UNITS")
MMA_UNITS = index_symbol("$MMA_UNITS")
VALU_UNITS = index_symbol("$VALU_UNITS")
SHUFFLE_UNITS = index_symbol("$SHUFFLE_UNITS")
78 changes: 75 additions & 3 deletions iree/turbine/kernel/wave/scheduling/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from dataclasses import dataclass
import sympy
import math
from functools import partial
import multiprocessing as mp

T = index_symbol("$INITIATION_INTERVAL")

Expand Down Expand Up @@ -157,7 +159,25 @@ def find_cycles_in_scc(scc: dict[fx.Node, list[fx.Node]]) -> list[list[fx.Node]]
return circuits


def all_pairs_longest_paths(
def all_pairs_longest_paths_helper(
graph: fx.Graph, u: fx.Node, dist: dict[tuple[fx.Node, fx.Node], IndexExpr], i: int
):
v = list(graph.nodes)[i]
for w in graph.nodes:
dist[(v, w)] = sympy.Max(dist[(v, w)], dist[(v, u)] + dist[(u, w)])
return v, dist


def all_pairs_longest_path_parallel(N: int, D: np.array, k: int, i: int):
"""
This function is called once for a different value of i.
"""
for j in range(N):
D[i, j] = np.maximum(D[i, j], D[i, k] + D[k, j])
return i, D[i]


def all_pairs_longest_paths_symbolic(
graph: fx.Graph,
edges: list[Edge],
) -> dict[tuple[fx.Node, fx.Node], IndexExpr]:
Expand All @@ -181,6 +201,51 @@ def all_pairs_longest_paths(
return D


def all_pairs_longest_paths(
graph: fx.Graph,
edges: list[Edge],
T: int,
) -> dict[tuple[fx.Node, fx.Node], IndexExpr]:
"""
For each node in the graph, compute the longest path to all other nodes.
Uses the Floyd-Warshall algorithm and assumes that the cycles don't
have positive weights. This function computes the distances in parallel
by parallelizing across the start nodes.
"""
N = len(graph.nodes)
D = np.zeros((N, N), dtype=np.float32)
negative_inf = -np.inf
for i in range(N):
for j in range(N):
D[i, j] = negative_inf

all_nodes = list(graph.nodes)
for edge in edges:
i = all_nodes.index(edge._from)
j = all_nodes.index(edge._to)
D[i, j] = edge.weight.delay - edge.weight.iteration_difference * T

# Parallel implementation
pool = mp.get_context("fork").Pool(processes=mp.cpu_count())
for k in range(N):
func = partial(all_pairs_longest_path_parallel, N, D, k)
results = pool.map(func, range(N))
for result in results:
D[result[0]] = result[1]
pool.close()
pool.join()

# Convert from index to node based representation.
G: dict[tuple[fx.Node, fx.Node], int] = {}
for i, from_node in enumerate(graph.nodes):
for j, to_node in enumerate(graph.nodes):
if np.isinf(D[i, j]) or i == j:
continue
G[(from_node, to_node)] = int(D[i, j])

return G


def evaluate_all_pairs_longest_paths(
D: dict[tuple[fx.Node, fx.Node], IndexExpr], initiation_interval: int
) -> dict[tuple[fx.Node, fx.Node], int]:
Expand All @@ -190,7 +255,8 @@ def evaluate_all_pairs_longest_paths(
"""
D_static = dict(D)
for key in D_static:
D_static[key] = D_static[key].subs(T, initiation_interval)
if isinstance(D_static[key], sympy.Expr):
D_static[key] = D_static[key].subs(T, initiation_interval)
# Remove the negative infinity values and edges to self.
for k in list(D_static.keys()):
if math.isinf(D_static[k]) or k[0] == k[1]:
Expand Down Expand Up @@ -244,8 +310,14 @@ def get_scheduling_weight(node: fx.Node) -> EdgeWeight:
weight = EdgeWeight(0, delay_table[Operation.MMA])
case IterArg():
weight = EdgeWeight(1, 0)
case CastOp():
case CastOp() | Extract() | Permute() | Broadcast() | Reshape():
weight = EdgeWeight(0, delay_table[Operation.NOOP])
case UnaryPyOp():
weight = EdgeWeight(0, delay_table[Operation.VALU])
case BinaryPyOp():
weight = EdgeWeight(0, delay_table[Operation.VALU])
case ShuffleOp():
weight = EdgeWeight(0, delay_table[Operation.SHUFFLE])
case _:
raise ValueError(f"Unsupported node type: {custom_node}")
weight.delay = subs_idxc(weight.delay)
Expand Down
94 changes: 76 additions & 18 deletions iree/turbine/kernel/wave/scheduling/loop_reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
GetResult,
get_custom,
SchedulingGroupBarrier,
MMA,
NewRegister,
)
from .modulo_scheduling import ModuloScheduler
from ..utils import (
Expand Down Expand Up @@ -99,10 +101,11 @@ def add_nodes_by_schedule(
# Set the index for the new node by substituting the induction variable
# for the current iteration.
new_node.index = node.index
for dim in new_node.index:
new_node.index[dim] = new_node.index[dim].subs(
{induction_variable: current_induction_variables[iteration]}
)
if new_node.index:
for dim in new_node.index:
new_node.index[dim] = new_node.index[dim].subs(
{induction_variable: current_induction_variables[iteration]}
)
# Add scheduling parameters for debugging.
new_node.scheduling_parameters = node.scheduling_parameters
# Update the rotating registers and argument context for the current node (if applicable).
Expand All @@ -117,21 +120,25 @@ def add_nodes_by_schedule(
# Update the init args in the argument context whenever a result is computed.
if node in arg_context.results:
if (
pipelining_stage == PipelineStage.KERNEL
or pipelining_stage == PipelineStage.EPILOGUE
pipelining_stage == PipelineStage.EPILOGUE
or pipelining_stage == PipelineStage.KERNEL
):
logger.debug(
f"Updating result: {node} -> {arg_context.result_to_iter_arg[node]} to {new_node.fx_node}."
)
arg_context.map_arg_all(
arg_context.result_to_iter_arg[node], new_node.fx_node
arg_context.map_arg_all_after_iteration(
arg_context.result_to_iter_arg[node],
new_node.fx_node,
iteration,
)
if pipelining_stage == PipelineStage.PROLOGUE:
logger.debug(
f"Updating result: {node} -> {arg_context.result_to_init_arg[node]} to {new_node.fx_node}."
)
arg_context.map_arg_all(
arg_context.result_to_init_arg[node], new_node.fx_node
arg_context.map_arg_all_after_iteration(
arg_context.result_to_init_arg[node],
new_node.fx_node,
iteration,
)

if pipelining_stage == PipelineStage.KERNEL and use_scheduling_barriers:
Expand All @@ -154,6 +161,21 @@ def push_placeholders(
arg_context.map_arg_all(node, root_node)


def add_missing_registers(graph: fx.Graph):
""" """
for node in graph.nodes:
custom = get_custom(node)
if isinstance(custom, MMA):
acc = get_custom(custom.acc)
if acc.graph != custom.graph:
with custom.graph.inserting_before(node):
register = NewRegister(
acc.shape, acc.dtype, acc.value
).add_to_graph(custom.graph)
register.index = acc.index
custom.update_arg("acc", register)


def construct_prologue(
reduction_subgraph: fx.Graph,
reduction: Reduction,
Expand Down Expand Up @@ -213,6 +235,12 @@ def construct_prologue(
new_init_args.append(mapped_init_arg)
reduction.init_args = new_init_args

# Add missing registers. Since registers are not present
# in the scheduling code, we could end up with a situation where
# we move mma ops outside the reduction that do not have a corresponding
# register. We remedy this in the function below.
add_missing_registers(reduction.graph)


def flatten_dict_values(
rotating_registers: dict[fx.Node, list[fx.Node]]
Expand Down Expand Up @@ -386,6 +414,12 @@ def construct_kernel(
"kernel.png",
)

# Add missing registers. Since registers are not present
# in the scheduling code, we could end up with a situation where
# we move mma ops outside the reduction that do not have a corresponding
# register. We remedy this in the function below.
add_missing_registers(pipelined_reduction_graph)

return pipelined_reduction, pipelined_reduction_graph


Expand Down Expand Up @@ -422,16 +456,34 @@ def construct_epilogue(
scheduler.num_stages,
)

existing_get_results: list[GetResult] = sorted(
[x for x in pipelined_reduction.users if isinstance(x, GetResult)],
key=lambda x: x.res_idx,
)
existing_users = {x: x.users for x in existing_get_results}
existing_get_results: list[GetResult] = [
x for x in pipelined_reduction.users if isinstance(x, GetResult)
]
existing_indices = [x.res_idx for x in existing_get_results]

# Map the results from the kernel to the init args (for stages).
for iter_arg, get_result in zip(
reduction.iter_args(reduction_subgraph), existing_get_results
):
# The number of iter args may not be the same as the number of get results
# and so we have to add additional get results for the missing iter args.
# This happens if some of the iter args have no uses outside the reduction
# (such as the max value in flash attention). While they may not have any
# uses in the original reduction, they will have uses in the pipelined
# reduction outside the reduction and so need to be added in the correct order.
iter_args = reduction.iter_args(reduction_subgraph)
for i in range(len(iter_args)):
if i in existing_indices:
continue
with pipelined_reduction.graph.inserting_before(
existing_get_results[0].fx_node.next
):
result = GetResult(pipelined_reduction.fx_node, i).add_to_graph(
pipelined_reduction.graph
)
existing_get_results.append(get_custom(result))

existing_get_results = sorted(existing_get_results, key=lambda x: x.res_idx)
existing_users = {x: x.users for x in existing_get_results}

for iter_arg, get_result in zip(iter_args, existing_get_results):
arg_context.map_arg_all(iter_arg, get_result.fx_node)

with pipelined_reduction.graph.inserting_before(
Expand Down Expand Up @@ -474,6 +526,12 @@ def construct_epilogue(
for i, get_result in enumerate(existing_get_results):
replace_uses_in(existing_users, get_result, new_results[i])

# Add missing registers. Since registers are not present
# in the scheduling code, we could end up with a situation where
# we move mma ops outside the reduction that do not have a corresponding
# register. We remedy this in the function below.
add_missing_registers(pipelined_reduction.graph)

if visualize:
visualize_mapped_graphs(
pipelined_reduction.graph,
Expand Down
22 changes: 22 additions & 0 deletions iree/turbine/kernel/wave/scheduling/loop_reconstruction_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,28 @@ def map_arg_all(self, from_: fx.Node, to_: fx.Node) -> None:
for stage in range(self.num_stages):
self.argument_map[iteration][stage][from_] = to_

def map_arg_all_after_iteration(
self, from_: fx.Node, to_: fx.Node, iteration: int
) -> None:
"""
Maps the given argument from one to another into the argument context for all stages
after the specified iteration.
"""
for iteration in range(iteration + 1, self.num_iterations):
for stage in range(self.num_stages):
self.argument_map[iteration][stage][from_] = to_

def map_arg_all_before_iteration(
self, from_: fx.Node, to_: fx.Node, iteration: int
) -> None:
"""
Maps the given argument from one to another into the argument context for all stages
before the specified iteration.
"""
for iteration in range(0, iteration):
for stage in range(self.num_stages):
self.argument_map[iteration][stage][from_] = to_

def map_arg_all_iterations(self, stage: int, from_: fx.Node, to_: fx.Node) -> None:
"""
Maps the given argument from one to another into the argument context for all stages
Expand Down
5 changes: 1 addition & 4 deletions iree/turbine/kernel/wave/scheduling/modulo_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,17 +106,14 @@ def schedule_graph(self) -> tuple[dict[fx.Node, int], bool]:
# Initialize initiation interval.
T0 = int(max(self.compute_resource_ii(), self.compute_recurrence_ii(sccs)))

# Compute symbolic all pairs longest path.
e_star_symbolic = all_pairs_longest_paths(self.graph, self.edges)

# Generate the schedule.
# TODO: Come up with a better heuristic on an upper bound for the initiation interval.
T_max_range = 3 * T0
success = False
for T in range(T0, T0 + T_max_range):
logger.debug(f"Trying initiation interval: {T}.")
self.RT = np.zeros((T, len(self.resources)))
self.e_star = evaluate_all_pairs_longest_paths(e_star_symbolic, T)
self.e_star = all_pairs_longest_paths(self.graph, self.edges, T)
logger.debug(f"All Pairs Longest Paths: {self.e_star}.")
self.schedule: dict[fx.Node, int] = {}
for _, scc in topological_sort(sccs).items():
Expand Down
Loading

0 comments on commit 176cffa

Please sign in to comment.