
Commit 485cc0f
Revert "[inductor] benchmark fusion (pytorch#108193)"
This reverts commit ec0cdcd.

Reverted pytorch#108193 on behalf of https://github.com/ZainRizvi due to This test is breaking trunk. In the future please make sure to add the ciflow/trunk label before force merging any PR to ensure your code doesn't break those tests ([comment](pytorch#108193 (comment)))
pytorchmergebot committed Oct 26, 2023
1 parent 7da713b commit 485cc0f
Showing 7 changed files with 11 additions and 342 deletions.
126 changes: 0 additions & 126 deletions test/inductor/test_benchmark_fusion.py

This file was deleted.

21 changes: 8 additions & 13 deletions torch/_inductor/codegen/common.py
@@ -438,14 +438,14 @@ def __init__(self, name, line):
         self.name = name
 
     def __call__(self):
-        if all(
-            self.name not in x
-            for x in (
-                V.graph.removed_buffers,
-                V.kernel.removed_buffers,
-                V.graph.inplaced_to_remove,
-                V.kernel.inplaced_to_remove,
-            )
+        # V.kernel may be null since this method may be called for the
+        # wrapper codegen where there is no specific kernel.
+        if (
+            self.name
+            not in (
+                V.graph.removed_buffers | getattr(V.kernel, "removed_buffers", set())
+            )
+            and self.name not in V.graph.inplaced_to_remove
         ):
             return self.line
         return None
@@ -647,10 +647,7 @@ def aliases(self):
             if self._buffer_is_marked_removed(inplaced):
                 continue
             for other in inplaced.other_names:
-                if (
-                    other in V.graph.inplaced_to_remove
-                    or other in V.kernel.inplaced_to_remove
-                ):
+                if other in V.graph.inplaced_to_remove:
                     continue
                 if other in self.input_buffers:
                     yield self.input_buffers[other], inplaced.inner_name
@@ -891,8 +888,6 @@ def __init__(self, args=None, increase_kernel_count=True):
         self.indirect_max_sizes: Dict[Tuple[str, str], Tuple[sympy.Expr, str]] = {}
 
         self.removed_buffers = set()
-        self.inplaced_to_remove = set()
-
         # key: the buffer to write
         # value: the buffer to read and whose memory can be reused for
         # the buffer specified by key
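Note on the restored check above: the comment that returns with it explains that V.kernel may have no kernel-specific attributes when DeferredLine runs during wrapper codegen, so the lookup falls back to an empty set via getattr. A minimal, self-contained sketch of that pattern; NullKernelHandler and TritonLikeKernel are illustrative stand-ins, not inductor classes:

```python
# Sketch of the getattr-with-default pattern restored in common.py.
# NullKernelHandler mimics V.kernel during wrapper codegen, where no
# kernel-specific handler (and hence no `removed_buffers`) exists.
class NullKernelHandler:
    pass


class TritonLikeKernel:
    def __init__(self):
        self.removed_buffers = {"buf1"}


def buffer_is_removed(name, graph_removed_buffers, kernel):
    # Fall back to an empty set when the handler lacks `removed_buffers`,
    # mirroring `getattr(V.kernel, "removed_buffers", set())` above.
    return name in (graph_removed_buffers | getattr(kernel, "removed_buffers", set()))


print(buffer_is_removed("buf1", {"buf0"}, TritonLikeKernel()))   # True
print(buffer_is_removed("buf1", {"buf0"}, NullKernelHandler()))  # False
```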
1 change: 0 additions & 1 deletion torch/_inductor/codegen/cpp.py
@@ -2661,7 +2661,6 @@ def run(kernel):
 
         scalar_kernel = codegen_kernel(CppKernel)
         V.graph.removed_buffers |= scalar_kernel.removed_buffers
-        V.graph.inplaced_to_remove |= scalar_kernel.inplaced_to_remove
         self.loop_nest = LoopNestWithSplit.build(scalar_kernel)
 
         if not self.picked_vec_isa:
82 changes: 1 addition & 81 deletions torch/_inductor/codegen/triton.py
@@ -8,7 +8,6 @@
 import logging
 import math
 import operator
-import os
 from typing import Any, Counter, Dict, Iterable, List, Optional, Set, Tuple
 
 import sympy
@@ -21,14 +20,13 @@
 from torch.utils._sympy.value_ranges import ValueRanges
 from ..._dynamo.utils import counters
 from .. import config, ir, scheduler
-from ..codecache import code_hash, get_path, PyCodeCache
+from ..codecache import code_hash, get_path
 from ..dependencies import MemoryDep, StarDep
 from ..ir import IRNode, ReductionHint, TritonTemplateBuffer
 from ..optimize_indexing import indexing_dtype_strength_reduction
 from ..scheduler import BaseScheduling
 from ..triton_heuristics import AutotuneHint
 from ..utils import (
-    do_bench,
     get_fused_kernel_name,
     get_kernel_metadata,
     green_text,
@@ -2523,7 +2521,6 @@ def codegen_node_schedule(self, node_schedule, numel, reduction_numel):
         self.codegen_comment(node_schedule)
         kernel.call_kernel(kernel_name)
         V.graph.removed_buffers |= kernel.removed_buffers
-        V.graph.inplaced_to_remove |= kernel.inplaced_to_remove
 
         if config.warn_mix_layout:
             kernel.warn_mix_layout(kernel_name)
@@ -2643,7 +2640,6 @@ def codegen_template(self, template_node, epilogue_nodes):
         self.codegen_comment(node_schedule)
         kernel.call_kernel(kernel_name, template_node.node)
         V.graph.removed_buffers |= kernel.removed_buffers
-        V.graph.inplaced_to_remove |= kernel.inplaced_to_remove
         self.scheduler.free_buffers()
 
     def codegen_sync(self):
@@ -2681,7 +2677,6 @@ def codegen_foreach(self, foreach_node):
                     if node not in (EnableReduction, DisableReduction):
                         node.mark_run()
                 V.graph.removed_buffers |= subkernel.removed_buffers
-                V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove
 
         src_code = kernel.codegen_kernel()
         kernel_name = self.define_kernel(src_code, [foreach_node])
@@ -2830,81 +2825,6 @@ def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)):
     def flush(self):
         pass
 
-    def benchmark_fused_nodes(self, nodes):
-        _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
-        node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
-        tiled_groups = self.select_tiling(node_schedule, numel, rnumel)
-        reduction_hint_val, mutations, index_dtype = self.get_kernel_args(
-            node_schedule, numel, rnumel
-        )
-
-        kernel = TritonKernel(
-            *tiled_groups,
-            reduction_hint=reduction_hint_val,
-            mutations=mutations,
-            index_dtype=index_dtype,
-        )
-
-        # empty last_usage. May cause more aggressive 'evict_last'. Should be fine.
-        for n in nodes:
-            n.last_usage = set()
-
-        self.codegen_node_schedule_with_kernel(node_schedule, kernel)
-        with config.patch("benchmark_kernel", True), V.set_kernel_handler(kernel):  # type: ignore[attr-defined]
-            src_code = kernel.codegen_kernel()
-
-        src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
-        mod = PyCodeCache.load(src_code)
-
-        def cache_file_path():
-            return os.path.splitext(mod.__file__)[0] + ".kernel_perf"  # type: ignore[type-var,operator]
-
-        def load_cache():
-            path = cache_file_path()
-            if os.path.exists(path):
-                with open(path) as fd:
-                    return float(fd.read())
-            return None
-
-        def store_cache():
-            path = cache_file_path()
-            with open(path, "w") as fd:
-                fd.write(str(ms))
-
-        log.debug(
-            "kernel src code for %s written to: %s",
-            {n.get_name() for n in nodes},
-            mod.__file__,
-        )
-        ms = load_cache()
-        if ms is not None:
-            return ms
-
-        args = mod.get_args()
-        call = mod.call
-        wrapped_jit_function = mod.triton_
-
-        # call once to trigger the compilation
-        call(wrapped_jit_function.clone_args(*args))
-
-        launchers = wrapped_jit_function.launchers
-        assert len(launchers) == 1
-        if launchers[0].n_spills > 0:
-            # skip benchmarking the kernel if there are register spills
-            ms = float("inf")
-        else:
-            # We have to clone the inplace updated arguments to avoid earlier calls
-            # generating out of range indices for later calls.
-            ms = do_bench(lambda: call(wrapped_jit_function.clone_args(*args)))
-
-        log.debug(
-            "The fused kernel for %s took %.3f ms to run",
-            {n.get_name() for n in nodes},
-            ms,
-        )
-        store_cache()
-        return ms
-
 
 @dataclasses.dataclass
 class CandidateTiling:
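The deleted benchmark_fused_nodes above compiles a candidate fused kernel, caches its measured runtime in a .kernel_perf file next to the generated source, and returns the time in milliseconds (or float("inf") when the compiled kernel spills registers). A rough sketch of how a benchmark-driven fusion decision could consume that contract; should_fuse and min_speedup are illustrative names, not the actual inductor scheduler API:

```python
import math


def should_fuse(backend, nodes_a, nodes_b, min_speedup=1.0):
    """Illustrative only: decide whether fusing two groups of nodes pays off
    by comparing benchmarked runtimes. `backend.benchmark_fused_nodes` is
    assumed to follow the deleted method's contract: it returns milliseconds,
    or float("inf") if the candidate kernel spills registers."""
    ms_separate = backend.benchmark_fused_nodes(nodes_a) + backend.benchmark_fused_nodes(nodes_b)
    ms_fused = backend.benchmark_fused_nodes(nodes_a + nodes_b)
    if math.isinf(ms_fused):
        # The fused candidate spilled registers; keep the kernels separate.
        return False
    return ms_fused * min_speedup <= ms_separate
```

The on-disk cache keyed by the generated kernel's source path keeps repeated scheduling passes from re-benchmarking an identical fused kernel.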
1 change: 0 additions & 1 deletion torch/_inductor/config.py
@@ -198,7 +198,6 @@
 # For each fused kernel in the wrapper, comment with the nodes that get fused.
 # Useful for debugging fusion.
 debug_fusion = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1"
-benchmark_fusion = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1"
 
 # how many nodes to allow into a single fusion
 max_fusion_size = 64
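Prior to this revert, the benchmark_fusion flag removed above was read from the environment at import time; assuming the usual inductor config mechanisms (the same config.patch used elsewhere in this diff), it could be toggled roughly as sketched below. Once the revert lands, the attribute no longer exists, so both toggles stop working.

```python
import os

# Option 1: environment variable, which the removed config line read at import
# time; it must be set before torch._inductor.config is first imported.
os.environ["TORCHINDUCTOR_BENCHMARK_FUSION"] = "1"

# Option 2: patch the config programmatically (assumed to behave like other
# inductor flags, e.g. the benchmark_kernel patch shown in this diff).
import torch._inductor.config as inductor_config

with inductor_config.patch(benchmark_fusion=True):
    pass  # compile and run the model with benchmark-driven fusion enabled
```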