diff --git a/test/inductor/test_benchmark_fusion.py b/test/inductor/test_benchmark_fusion.py
deleted file mode 100644
index 276e0ef4de196..0000000000000
--- a/test/inductor/test_benchmark_fusion.py
+++ /dev/null
@@ -1,126 +0,0 @@
-# Owner(s): ["module: inductor"]
-import math
-import os
-import sys
-
-import torch
-from torch.testing._internal.common_utils import (
-    TEST_WITH_ASAN,
-    TestCase as TorchTestCase,
-)
-from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
-
-# Make the helper files in test/ importable
-pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
-sys.path.append(pytorch_test_dir)
-
-import contextlib
-import unittest
-
-from torch._inductor import config
-from torch._inductor.scheduler import Scheduler
-
-from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests
-
-
-class TestCase(TorchTestCase):
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-        cls._stack = contextlib.ExitStack()
-        cls._stack.enter_context(
-            config.patch(
-                {
-                    "benchmark_kernel": True,
-                    "benchmark_fusion": True,
-                }
-            )
-        )
-
-    @classmethod
-    def tearDownClass(cls):
-        cls._stack.close()
-        super().tearDownClass()
-
-
-class BenchmarkFusionTestTemplate:
-    def test_softmax(self):
-        def f(x):
-            return torch.nn.functional.softmax(x, dim=-1)
-
-        self.common(f, (torch.rand(2, 8192),))
-
-    def test_resnet18(self):
-        import torchvision
-
-        model = torchvision.models.resnet18()
-        model.eval()
-        batch_size = 16
-        inputs = (torch.randn((batch_size, 3, 224, 224)),)
-        self.common(model, inputs, atol=1e-2, rtol=1e-2)
-
-    def test_register_spills(self):
-        """
-        The test can potentially trigger register spills
-        """
-        old_benchmark_fn = Scheduler.benchmark_fused_nodes
-
-        def new_benchmark_fn(scheduler, nodes):
-            """
-            We override Scheduler.benchmark_fused_nodes to return latency 1.0
-            if there are no register spills. Without this, we may not able to
-            test the code path handling register spilling because before register
-            start spilling, the related fusion may have already been skipped
-            due to longer lantency.
-            """
-            ms = old_benchmark_fn(scheduler, nodes)
-            if not math.isinf(ms):
-                ms = 1.0
-            return ms
-
-        # Disable dynamic_scale_rblock to make it easier to trigger register
-        # spilling.
-        with unittest.mock.patch.object(
-            Scheduler, "benchmark_fused_nodes", new_benchmark_fn
-        ), config.patch("dynamic_scale_rblock", False):
-            S = 512
-
-            def f(*inputs):
-                inputs = list(inputs)
-                outputs = []
-                out = torch.zeros(S, device=self.device)
-                for x in inputs:
-                    x = x * 2
-                    x = x + 1
-                    x = x.sum(dim=-1)
-                    outputs.append(x)
-                    out = out + x
-                return outputs, out
-
-            N = int(os.environ.get("NINP", "30"))
-            inputs = [torch.randn(S, 2560, device=self.device) for _ in range(N)]
-            opt_f = torch.compile(f)
-            opt_f(*inputs)
-
-
-if HAS_CUDA and not TEST_WITH_ASAN:
-
-    class BenchmarkFusionCudaTest(TestCase):
-        common = check_model_cuda
-        device = "cuda"
-
-    copy_tests(BenchmarkFusionTestTemplate, BenchmarkFusionCudaTest, "cuda")
-
-if HAS_CPU and not torch.backends.mps.is_available():
-
-    class BenchmarkFusionCpuTest(TestCase):
-        common = check_model
-        device = "cpu"
-
-    copy_tests(BenchmarkFusionTestTemplate, BenchmarkFusionCpuTest, "cpu")
-
-if __name__ == "__main__":
-    from torch._dynamo.test_case import run_tests
-
-    if HAS_CPU or HAS_CUDA:
-        run_tests()
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index f0ca422461517..9ab92e22146a7 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -438,14 +438,14 @@ def __init__(self, name, line):
         self.name = name
 
     def __call__(self):
-        if all(
-            self.name not in x
-            for x in (
-                V.graph.removed_buffers,
-                V.kernel.removed_buffers,
-                V.graph.inplaced_to_remove,
-                V.kernel.inplaced_to_remove,
+        # V.kernel may be null since this method may be called for the
+        # wrapper codegen where there is no specific kernel.
+        if (
+            self.name
+            not in (
+                V.graph.removed_buffers | getattr(V.kernel, "removed_buffers", set())
             )
+            and self.name not in V.graph.inplaced_to_remove
         ):
             return self.line
         return None
@@ -647,10 +647,7 @@ def aliases(self):
             if self._buffer_is_marked_removed(inplaced):
                 continue
             for other in inplaced.other_names:
-                if (
-                    other in V.graph.inplaced_to_remove
-                    or other in V.kernel.inplaced_to_remove
-                ):
+                if other in V.graph.inplaced_to_remove:
                     continue
                 if other in self.input_buffers:
                     yield self.input_buffers[other], inplaced.inner_name
@@ -891,8 +888,6 @@ def __init__(self, args=None, increase_kernel_count=True):
         self.indirect_max_sizes: Dict[Tuple[str, str], Tuple[sympy.Expr, str]] = {}
 
         self.removed_buffers = set()
-        self.inplaced_to_remove = set()
-
         # key: the buffer to write
         # value: the buffer to read and whose memory can be reused for
         # the buffer specified by key
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index ccb242cfb73b2..7a1aedf6413a7 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -2661,7 +2661,6 @@ def run(kernel):
 
         scalar_kernel = codegen_kernel(CppKernel)
         V.graph.removed_buffers |= scalar_kernel.removed_buffers
-        V.graph.inplaced_to_remove |= scalar_kernel.inplaced_to_remove
         self.loop_nest = LoopNestWithSplit.build(scalar_kernel)
 
         if not self.picked_vec_isa:
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index f486e63f1cd79..c3d05eb3e3501 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -8,7 +8,6 @@
 import logging
 import math
 import operator
-import os
 from typing import Any, Counter, Dict, Iterable, List, Optional, Set, Tuple
 
 import sympy
@@ -21,14 +20,13 @@
 from torch.utils._sympy.value_ranges import ValueRanges
 from ..._dynamo.utils import counters
 from .. import config, ir, scheduler
-from ..codecache import code_hash, get_path, PyCodeCache
+from ..codecache import code_hash, get_path
 from ..dependencies import MemoryDep, StarDep
 from ..ir import IRNode, ReductionHint, TritonTemplateBuffer
 from ..optimize_indexing import indexing_dtype_strength_reduction
 from ..scheduler import BaseScheduling
 from ..triton_heuristics import AutotuneHint
 from ..utils import (
-    do_bench,
     get_fused_kernel_name,
     get_kernel_metadata,
     green_text,
@@ -2523,7 +2521,6 @@ def codegen_node_schedule(self, node_schedule, numel, reduction_numel):
         self.codegen_comment(node_schedule)
         kernel.call_kernel(kernel_name)
         V.graph.removed_buffers |= kernel.removed_buffers
-        V.graph.inplaced_to_remove |= kernel.inplaced_to_remove
 
         if config.warn_mix_layout:
             kernel.warn_mix_layout(kernel_name)
@@ -2643,7 +2640,6 @@ def codegen_template(self, template_node, epilogue_nodes):
         self.codegen_comment(node_schedule)
         kernel.call_kernel(kernel_name, template_node.node)
         V.graph.removed_buffers |= kernel.removed_buffers
-        V.graph.inplaced_to_remove |= kernel.inplaced_to_remove
         self.scheduler.free_buffers()
 
     def codegen_sync(self):
@@ -2681,7 +2677,6 @@ def codegen_foreach(self, foreach_node):
                         if node not in (EnableReduction, DisableReduction):
                             node.mark_run()
                 V.graph.removed_buffers |= subkernel.removed_buffers
-                V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove
 
             src_code = kernel.codegen_kernel()
             kernel_name = self.define_kernel(src_code, [foreach_node])
@@ -2830,81 +2825,6 @@ def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)):
     def flush(self):
         pass
 
-    def benchmark_fused_nodes(self, nodes):
-        _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
-        node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
-        tiled_groups = self.select_tiling(node_schedule, numel, rnumel)
-        reduction_hint_val, mutations, index_dtype = self.get_kernel_args(
-            node_schedule, numel, rnumel
-        )
-
-        kernel = TritonKernel(
-            *tiled_groups,
-            reduction_hint=reduction_hint_val,
-            mutations=mutations,
-            index_dtype=index_dtype,
-        )
-
-        # empty last_usage. May cause more aggressive 'evict_last'. Should be fine.
-        for n in nodes:
-            n.last_usage = set()
-
-        self.codegen_node_schedule_with_kernel(node_schedule, kernel)
-        with config.patch("benchmark_kernel", True), V.set_kernel_handler(kernel):  # type: ignore[attr-defined]
-            src_code = kernel.codegen_kernel()
-
-        src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
-        mod = PyCodeCache.load(src_code)
-
-        def cache_file_path():
-            return os.path.splitext(mod.__file__)[0] + ".kernel_perf"  # type: ignore[type-var,operator]
-
-        def load_cache():
-            path = cache_file_path()
-            if os.path.exists(path):
-                with open(path) as fd:
-                    return float(fd.read())
-            return None
-
-        def store_cache():
-            path = cache_file_path()
-            with open(path, "w") as fd:
-                fd.write(str(ms))
-
-        log.debug(
-            "kernel src code for %s written to: %s",
-            {n.get_name() for n in nodes},
-            mod.__file__,
-        )
-        ms = load_cache()
-        if ms is not None:
-            return ms
-
-        args = mod.get_args()
-        call = mod.call
-        wrapped_jit_function = mod.triton_
-
-        # call once to trigger the compilation
-        call(wrapped_jit_function.clone_args(*args))
-
-        launchers = wrapped_jit_function.launchers
-        assert len(launchers) == 1
-        if launchers[0].n_spills > 0:
-            # skip benchmarking the kernel if there are register spills
-            ms = float("inf")
-        else:
-            # We have to clone the inplace updated arguments to avoid earlier calls
-            # generating out of range indices for later calls.
-            ms = do_bench(lambda: call(wrapped_jit_function.clone_args(*args)))
-
-        log.debug(
-            "The fused kernel for %s took %.3f ms to run",
-            {n.get_name() for n in nodes},
-            ms,
-        )
-        store_cache()
-        return ms
-
 
 @dataclasses.dataclass
 class CandidateTiling:
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index dd9b942d09eec..287692589d52c 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -198,7 +198,6 @@
 # For each fused kernel in the wrapper, comment with the nodes that get fused.
 # Useful for debugging fusion.
 debug_fusion = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1"
-benchmark_fusion = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1"
 
 # how many nodes to allow into a single fusion
 max_fusion_size = 64
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index 40762a5dce23a..1ed30aeb83837 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -3,7 +3,6 @@
 import functools
 import itertools
 import logging
-import math
 import os
 import pprint
 import textwrap
@@ -29,8 +28,6 @@
     get_device_tflops,
     get_dtype_size,
     get_gpu_dram_gbps,
-    green_text,
-    red_text,
     sympy_product,
 )
 from .virtualized import V
@@ -1498,97 +1495,6 @@ def fuse_nodes(self):
                 fusion_log.debug("===== fusion complete (%d iterations) =====", i + 1)
                 break
 
-    def benchmark_fused_nodes(self, nodes):
-        """
-        Benchmark fused list of nodes and return the execution time
-        in milliseconds on randomly generated inputs.
-        """
-        assert len(nodes) > 0
-        device = nodes[0].get_device()
-        V.graph.scheduler = self
-        self.current_device = device
-        backend = self.get_backend(device)
-        return backend.benchmark_fused_nodes(nodes)
-
-    def speedup_by_fusion(self, node1, node2):
-        """
-        If config.benchmark_fusion is False, always return True.
-        Otherwise, return True if fusion can brings speedup.
-        """
-        if not config.benchmark_fusion:
-            return True
-
-        if node1.is_template():
-            # TODO support benchmarking epilogue fusion
-            return True
-
-        node_list_1 = node1.get_nodes()
-        device = node_list_1[0].get_device()
-
-        # don't support benchmark fusion for CPU right now.
-        if device.type == "cpu":
-            return True
-
-        node_list_2 = node2.get_nodes()
-        node_list_fused = node_list_1 + node_list_2
-
-        # We can not accurately benchmark kernel using atomic_add
-        # due to how we generate random integer inputs.
-        # Skip benchmarking them by allowing fusion.
-        if any(
-            hasattr(n.node, "data")
-            and hasattr(n.node.data, "scatter_mode")
-            and n.node.data.scatter_mode == "atomic_add"
-            for n in node_list_fused
-        ):
-            return True
-
-        from triton.compiler.errors import CompilationError
-
-        try:
-            ms1 = self.benchmark_fused_nodes(node_list_1)
-            if math.isinf(ms1):
-                log.debug(
-                    "Skip fusion because of register spilling of the first kernel"
-                )
-                return False
-            ms2 = self.benchmark_fused_nodes(node_list_2)
-            if math.isinf(ms2):
-                log.debug(
-                    "Skip fusion because of register spilling of the second kernel"
-                )
-                return False
-            ms_fused = self.benchmark_fused_nodes(node_list_fused)
-            if math.isinf(ms_fused):
-                log.debug(
-                    "Skip fusion because of register spilling of the fused kernel"
-                )
-                return False
-        except CompilationError as e:
-            # workaround triton issue: https://github.com/openai/triton/issues/2151
-            if "Loop-carried variable" in str(e):
-                return True  # allow fusion
-            else:
-                raise
-
-        if log.isEnabledFor(logging.DEBUG):
-            if ms_fused < ms1 + ms2:
-                log.debug(
-                    "Fusing %s with %s cause %sx speedup",
-                    node1.get_names(),
-                    node2.get_names(),
-                    green_text(f"{(ms1 + ms2) / ms_fused:.3f}"),
-                )
-            else:
-                log.debug(
-                    "Fusing %s with %s cause %sx slowdown",
-                    node1.get_names(),
-                    node2.get_names(),
-                    red_text(f"{ms_fused / (ms1 + ms2):.3f}"),
-                )
-
-        return ms_fused < ms1 + ms2
-
     def fuse_nodes_once(self):
         """
         Mutates self.nodes to combine nodes into FusedSchedulerNodes.
@@ -1604,8 +1510,6 @@ def fuse_nodes_once(self):
            if self.can_fuse(node1, node2) and not self.will_fusion_create_cycle(
                node1, node2
            ):
-                if not self.speedup_by_fusion(node1, node2):
-                    continue
                node3 = fuse(node1, node2)
                fused_nodes.remove(node1)
                fused_nodes.remove(node2)
@@ -1982,7 +1886,7 @@ def remove_filter(n):
                remove = all(n in names_to_remove for n in buf.other_names)
                if remove:
                    self.remove_inplace_buffer(name)
-                    V.kernel.inplaced_to_remove.add(name)
+                    V.graph.inplaced_to_remove.add(name)
            else:
                self.remove_buffer(name)
 
@@ -2184,10 +2088,3 @@ def flush(self):
         Flush the generated kernel and python wrapper code to the source code file.
         """
         raise NotImplementedError()
-
-    def benchmark_fused_nodes(self, nodes):
-        """
-        Benchmark fused list of nodes and return the execution time
-        in milliseconds on randomly generated inputs.
-        """
-        raise NotImplementedError()
diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py
index 6e981f9225fa9..593e949538293 100644
--- a/torch/_inductor/virtualized.py
+++ b/torch/_inductor/virtualized.py
@@ -56,21 +56,6 @@ class NullHandler:
     pass
 
 
-class NullKernelHandler(NullHandler):
-    """
-    We need access `V.kernel.removed_buffers` in DeferredLine class when there
-    is no kernel in the context. This happens when codegening the wrapper.
-    Initialize `removed_buffers` and `inplaced_to_remove` explicitly so we don't
-    need call 'getattr' with default value which is error prone to typo in
-    attribute name.
- """ - - def __init__(self): - super().__init__() - self.removed_buffers = set() - self.inplaced_to_remove = set() - - def _arg_str(a) -> str: if isinstance(a, sympy.Expr): return sympy_str(a) @@ -184,7 +169,7 @@ def __getattr__(self, item): _graph = Virtualized("graph", NullHandler) _real_inputs = Virtualized("real_inputs", NullHandler) _fake_mode = Virtualized("fake_mode", NullHandler) -_kernel = Virtualized("kernel", NullKernelHandler) +_kernel = Virtualized("kernel", NullHandler) _debug = Virtualized("debug", NullHandler) _interpreter = Virtualized("interpreter", NullHandler) _aot_compilation = Virtualized("aot_compilation", NullHandler)