diff --git a/test/inductor/test_group_batch_fusion.py b/test/inductor/test_group_batch_fusion.py
index 352cff11ace17..82f8241d0a226 100644
--- a/test/inductor/test_group_batch_fusion.py
+++ b/test/inductor/test_group_batch_fusion.py
@@ -6,8 +6,7 @@
 import torch
 import torch._inductor
 from torch._dynamo.test_case import run_tests, TestCase
-from torch._dynamo.utils import counters
-from torch.testing._internal.inductor_utils import HAS_CUDA
+from torch._dynamo.utils import counters, optimus_scuba_log
 
 try:
     # importing this will register fbgemm lowerings for inductor
@@ -18,11 +17,8 @@
     has_fbgemm = False
     pass
 
-requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
-
-
 class MyModule(torch.nn.Module):
-    def __init__(self, z: int, has_bias: bool, device="cuda") -> None:
+    def __init__(self, z: int, has_bias: bool, device="cpu") -> None:
         super().__init__()
         self.z = z
         self.device = device
@@ -220,7 +216,6 @@ def forward(self, x):
         return torch.cat(div, dim=1)
 
 
-@requires_cuda
 @torch._inductor.config.patch(
     pre_grad_fusion_options={
         "batch_linear": {},
@@ -271,8 +266,8 @@ def test_group_linear_fusion(self):
         z = 10
         for has_bias in [True, False]:
             counters.clear()
-            module = MyModule(z, has_bias).to("cuda")
-            input = [torch.randn(z, z, device="cuda")]
+            module = MyModule(z, has_bias).to("cpu")
+            input = [torch.randn(z, z, device="cpu")]
             traced = torch.compile(module)
             ref = module(*input)
             res = traced(*input)
@@ -285,6 +280,7 @@ def test_group_linear_fusion(self):
                 counters["inductor"]["batch_fusion"],
                 0,
             )
+            self.assertNotIn("group_batch_fusion_pre_grad", optimus_scuba_log)
             ref.sum().backward()
             res.sum().backward()
             self.compare_parameters(module, traced)
@@ -297,13 +293,14 @@ def test_group_linear_fusion(self):
                 counters["inductor"]["batch_fusion"],
                 3,
             )
+            self.assertIn("group_batch_fusion_post_grad", optimus_scuba_log)
             counters.clear()
 
     @unittest.skipIf(not has_fbgemm, "requires fbgemm")
     def test_group_linear_fusion_different_shapes(self):
         counters.clear()
-        module = MyModule2().eval().to("cuda")
-        input = [torch.rand(4, 24, device="cuda")]
+        module = MyModule2().eval().to("cpu")
+        input = [torch.rand(4, 24, device="cpu")]
         traced = torch.compile(module)
         ref = module(*input)
         res = traced(*input)
@@ -334,8 +331,8 @@ def test_batch_layer_norm_fusion(self):
         for has_weight in [True, False]:
             for has_bias in [True, False]:
                 counters.clear()
-                module = MyModule3("cuda", has_weight, has_bias).to("cuda")
-                input = [torch.randn(2, 5, 50, device="cuda")]
+                module = MyModule3("cpu", has_weight, has_bias).to("cpu")
+                input = [torch.randn(2, 5, 50, device="cpu")]
                 traced = torch.compile(module)
                 ref = module(*input)
                 res = traced(*input)
@@ -363,8 +360,8 @@ def test_batch_linear_lhs_fusion(self):
         z = 10
         for has_bias in [True, False]:
             counters.clear()
-            module = MyModule4(z, "cuda", has_bias)
-            input = [torch.randn(20, z, device="cuda")]
+            module = MyModule4(z, "cpu", has_bias)
+            input = [torch.randn(20, z, device="cpu")]
             traced = torch.compile(module)
             ref = module(*input)
             res = traced(*input)
@@ -387,8 +384,8 @@ def test_batch_linear_lhs_fusion(self):
     def test_batch_linear_pre_grad_fusion(self):
         for has_bias in [True, False]:
             counters.clear()
-            module = MyModule5("cuda", has_bias)
-            input = [torch.randn(50, 500, device="cuda")]
+            module = MyModule5("cpu", has_bias)
+            input = [torch.randn(50, 500, device="cpu")]
             traced = torch.compile(module)
             ref = module(*input)
             res = traced(*input)
@@ -411,8 +408,8 @@ def test_batch_linear_pre_grad_fusion(self):
 
     def test_pointwise_op_fusion(self):
         counters.clear()
-        module = TestPoitwiseOps("cuda")
-        input = [torch.randn(50, 1000, requires_grad=True, device="cuda")]
+        module = TestPoitwiseOps("cpu")
+        input = [torch.randn(50, 1000, requires_grad=True, device="cpu")]
         traced = torch.compile(module)
         ref = module(*input)
         res = traced(*input)
@@ -450,16 +447,15 @@ def forward(self, inputs):
         return output
 
 
-@requires_cuda
 @torch._inductor.config.patch(
     post_grad_fusion_options={"batch_linear_post_grad": {"require_fbgemm": False}}
 )
 class TestPostGradBatchLinearFusion(TestCase):
     def test_batch_linear_post_grad_fusion(self):
-        pt1_module = TestBMMFusionModule().cuda()
+        pt1_module = TestBMMFusionModule()
         inputs = []
         for _ in range(10):
-            inputs.append(torch.randn(10, 10).cuda())
+            inputs.append(torch.randn(10, 10))
         eager_output = pt1_module(inputs)
         pt2_module = torch.compile(pt1_module)
         pt2_output = pt2_module(inputs)
@@ -468,6 +464,8 @@ def test_batch_linear_post_grad_fusion(self):
             counters["inductor"]["batch_fusion"],
             2,
         )
+        self.assertNotIn("group_batch_fusion_pre_grad", optimus_scuba_log)
+        self.assertIn("group_batch_fusion_post_grad", optimus_scuba_log)
 
 
 class TestFindIndependentSubsetGreedy(TestCase):
diff --git a/test/inductor/test_split_cat_fx_passes.py b/test/inductor/test_split_cat_fx_passes.py
index 6d532ece14015..c3fc7ecb9ce59 100644
--- a/test/inductor/test_split_cat_fx_passes.py
+++ b/test/inductor/test_split_cat_fx_passes.py
@@ -2,7 +2,7 @@
 import torch
 
 from torch._dynamo.test_case import run_tests, TestCase
-from torch._dynamo.utils import counters
+from torch._dynamo.utils import counters, optimus_scuba_log
 from torch._inductor.fx_passes.misc_patterns import numpy_compat_normalization
 from torch.testing._internal.common_utils import IS_LINUX
 from torch.testing._internal.inductor_utils import HAS_CUDA
@@ -90,6 +90,10 @@ def cm_with_list(x):
                 counters["inductor"]["split_cat_norm"],
                 expected_split_norm_count,
             )
+            if expected_split_norm_count > 0:
+                self.assertIn(
+                    "split_cat_pattern_normalization_pass_pre_grad", optimus_scuba_log
+                )
             counters.clear()
 
     @patch
@@ -251,6 +255,10 @@ def split_getitem_out_of_order(x):
                 counters["inductor"]["consecutive_split_merged"],
                 expected_split_merged,
            )
+            if expected_split_merged > 0:
+                self.assertIn(
+                    "split_cat_pattern_merge_splits_pass_pre_grad", optimus_scuba_log
+                )
             counters.clear()
 
     @patch
@@ -1063,6 +1071,7 @@ def stack_tahn_unbind(x):
                 counters["inductor"]["stack_tahn_unbind_merged"],
                 expected_stack_tahn_unbind_merged,
             )
+            self.assertIn("split_cat_pattern_merge_getitem_cat_pass_pre_grad", optimus_scuba_log)
             counters.clear()
 
     def test_numpy_compat_normalization(self):
@@ -1087,7 +1096,7 @@ def test_stack_normalization_axis_kwarg(self):
         def fn(x, y):
             return torch.stack([x, y], axis=1)
 
-        x, y = (torch.rand((4, 4), device="cuda") for _ in range(2))
+        x, y = (torch.rand((4, 4), device="cpu") for _ in range(2))
         expected = fn(x, y)
         actual = torch.compile(fn)(x, y)
 
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index 95751bf32c9d4..3eaba129afbe6 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -101,6 +101,7 @@
 
 counters: DefaultDict[str, Counter[str]] = collections.defaultdict(collections.Counter)
+optimus_scuba_log: Dict[str, Any] = {}
 
 troubleshooting_url = "https://pytorch.org/docs/master/compile/troubleshooting.html"
 nnmodule_doc_url = "https://pytorch.org/docs/master/compile/nn-module.html"
 nnmodule_doc_url_msg = f"See {nnmodule_doc_url} for more information and limitations."
@@ -1154,10 +1155,7 @@ def dict_keys_repr(const_keys, *, local) -> str:
 GLOBAL_KEY_PREFIX = "__dict_key"
 
 
-from torch._subclasses import (  # noqa: F401
-    FakeTensorMode,
-    UnsupportedFakeTensorException,
-)
+from torch._subclasses import UnsupportedFakeTensorException  # noqa: F401
 
 
 def wrap_fake_exception(fn):
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index 2b7c6305f4927..288cac990b25b 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -33,13 +33,19 @@
     logging as dynamo_logging,
     utils as dynamo_utils,
 )
-from torch._dynamo.utils import counters, detect_fake_mode, lazy_format_graph_code
+from torch._dynamo.utils import (
+    counters,
+    detect_fake_mode,
+    lazy_format_graph_code,
+    optimus_scuba_log,
+)
 from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
 from torch._inductor.codecache import code_hash, CompiledFxGraph, FxGraphCache
 from torch._inductor.debug import save_args_for_compile_fx_inner
 from torch._ops import OpOverload
 from torch._subclasses.fake_tensor import FakeTensor
+from torch._utils_internal import signpost_event
 from torch.fx.passes.fake_tensor_prop import FakeTensorProp
 
 from .._dynamo.backends.common import aot_autograd
@@ -621,9 +627,11 @@ def fx_codegen_and_compile(
         post_grad_passes(gm, is_inference=is_inference)
         V.debug.fx_graph_transformed(gm, example_inputs)
         post_grad_graphs_log.debug("%s", lazy_format_graph_code("AFTER POST GRAD", gm))
-        log.debug(
-            "counters of inductor dict after apply passes on the input FX graph in the post grad pass: %s",
-            counters["inductor"],
+        optimus_scuba_log["inductor_post_grad"] = counters["inductor"]
+        signpost_event(
+            "optimus",
+            "compile_fx.post_grad_passes",
+            optimus_scuba_log,
         )
 
     with V.set_fake_mode(fake_mode):
@@ -1159,9 +1167,11 @@ def compile_fx(
             )
 
         model_ = pre_grad_passes(model_, example_inputs_)
-        log.debug(
-            "counters of inductor dict after apply passes on the input FX graph in the pre grad pass: %s",
-            counters["inductor"],
+        optimus_scuba_log["inductor_pre_grad"] = counters["inductor"]
+        signpost_event(
+            "optimus",
+            "compile_fx.pre_grad_passes",
+            optimus_scuba_log,
         )
 
     if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_):
diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py
index 8eab82f36d389..3aec31b503449 100644
--- a/torch/_inductor/fx_passes/group_batch_fusion.py
+++ b/torch/_inductor/fx_passes/group_batch_fusion.py
@@ -17,7 +17,6 @@
 
 import torch
 from torch._dynamo.utils import counters
-from torch._utils_internal import print_graph
 
 from .. import config
 from ..pattern_matcher import (
@@ -936,7 +935,6 @@ def generate_fusion_from_config(config_options: Dict[str, Any], pre_grad=True):
 
 
 def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):
-    print_graph(graph, "Before group_batch fusion in pre grad pass.")
     fusions: List[GroupBatchFusionBase] = []
     # we keep all current pre grad fusions to keep
     # current implementation, will remove this later
@@ -965,4 +963,3 @@ def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):
 
     for rule in fusions:
         apply_group_batch_fusion(graph, rule)  # type: ignore[arg-type]
-        print_graph(graph, f"Apply fusion {rule.__class__.__name__}.")
diff --git a/torch/_inductor/fx_passes/numeric_utils.py b/torch/_inductor/fx_passes/numeric_utils.py
index 98f8431f47268..b4baf12d4eae3 100644
--- a/torch/_inductor/fx_passes/numeric_utils.py
+++ b/torch/_inductor/fx_passes/numeric_utils.py
@@ -8,7 +8,6 @@
 
 import torch
 import torch.optim as optim
-from torch._utils_internal import print_graph
 
 from .. import config
 
@@ -43,8 +42,8 @@ def clean_memory() -> None:
 def compare_dict_tensors(dict_base, dict_control, precision):
     if len(set(dict_base.keys())) != len(set(dict_control.keys())):
         logger.warning("Mismatch keys found before and after pre/post grad fx passes.")
-        print_graph(dict_base.keys(), "keys before pre/post grad fx passes.")
-        print_graph(dict_control.keys(), "keys after pre/post grad fx passes.")
+        logger.debug("keys before pre/post grad fx passes %s", dict_base.keys())
+        logger.debug("keys after pre/post grad fx passes %s", dict_control.keys())
         return False
     is_allclose = True
     for key in dict_base.keys():
@@ -66,8 +65,8 @@ def compare_dict_tensors(dict_base, dict_control, precision):
             logger.warning(
                 "Mismatch parameter values found before and after pre/post grad fx passes."
             )
-            print_graph(dict_base[key], "value before pre/post grad fx passes.")
-            print_graph(dict_control[key], "value after pre/post grad fx passes.")
+            logger.debug("value before pre/post grad fx passes %s", dict_base[key])
+            logger.debug("value after pre/post grad fx passes %s", dict_control[key])
             is_allclose = False
     return is_allclose
 
@@ -92,9 +91,11 @@ def compare_tuple_tensors(tuple_base, tuple_control, precision):
             atol=precision,
             equal_nan=True,
         ):
-            print_graph(tuple_base[i], "forward output before pre/post grad fx passes.")
-            print_graph(
-                tuple_control[i], "forward output after pre/post grad fx passes."
+            logger.debug(
+                "forward output before pre/post grad fx passes %s", tuple_base[i]
+            )
+            logger.debug(
+                "forward output after pre/post grad fx passes %s", tuple_control[i]
             )
             is_allclose = False
     return is_allclose
diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py
index 9f160be82395c..88194c5fb142d 100644
--- a/torch/_inductor/fx_passes/post_grad.py
+++ b/torch/_inductor/fx_passes/post_grad.py
@@ -1,3 +1,4 @@
+import copy
 import functools
 import itertools
 import logging
@@ -12,10 +13,11 @@
 import torch.utils._pytree as pytree
 from torch import fx
 from torch._decomp import register_decomposition
+from torch._dynamo.utils import counters, optimus_scuba_log
 from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype
-from torch._utils_internal import print_graph
+from torch._utils_internal import upload_graph
 from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
 
 from .. import config, ir, pattern_matcher
@@ -80,18 +82,13 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
 
     if config.pattern_matcher:
         lazy_init()
-
-        print_graph(gm.graph, "Before group batch fusion in post grad pass.")
+        inductor_before_change = copy.deepcopy(counters["inductor"])
         group_batch_fusion_passes(gm.graph, pre_grad=False)
-        print_graph(gm.graph, "After group batch fusion in post grad pass.")
+        if counters["inductor"] != inductor_before_change:
+            optimus_scuba_log["group_batch_fusion_post_grad"] = upload_graph(gm.graph)
         remove_noop_ops(gm.graph)
-        print_graph(gm.graph, "Before split cat in post grad pass.")
         for patterns in pass_patterns:
             patterns.apply(gm.graph)  # type: ignore[arg-type]
-        print_graph(
-            gm.graph,
-            "Apply split cat pattern matcher PatternMatcherPass in post grad.",
-        )
 
     if is_inference:
         inference_patterns.apply(gm.graph)  # type: ignore[arg-type]
@@ -112,8 +109,6 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
     gm.recompile()
     gm.graph.lint()
 
-    print_graph(gm.graph, "After recompile in post grad pass.")
-
 
 @init_once_fakemode
 def lazy_init():
diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py
index 17c55364284bc..8a4d0a1e338a1 100644
--- a/torch/_inductor/fx_passes/pre_grad.py
+++ b/torch/_inductor/fx_passes/pre_grad.py
@@ -4,8 +4,8 @@
 
 import torch
 import torch.nn as nn
-from torch._dynamo.utils import detect_fake_mode
-from torch._utils_internal import print_graph
+from torch._dynamo.utils import counters, detect_fake_mode, optimus_scuba_log
+from torch._utils_internal import upload_graph
 from torch.fx.experimental.optimization import (
     matches_module_pattern,
     replace_node_module,
@@ -13,6 +13,7 @@
 from torch.fx.passes.shape_prop import ShapeProp
 from torch.nn import functional as F
 from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights
+
 from .. import config
 from ..fx_utils import matches_module_function_pattern
 
@@ -27,13 +28,27 @@
 
 log = logging.getLogger(__name__)
 
-normalization_pass = PatternMatcherPass(prevent_match_across_mutations=True)
-merge_splits_pass = PatternMatcherPass(prevent_match_across_mutations=True)
-split_cat_pass = PatternMatcherPass(prevent_match_across_mutations=True)
-unbind_stack_pass = PatternMatcherPass(prevent_match_across_mutations=True)
-efficient_conv_bn_eval_pass = PatternMatcherPass(prevent_match_across_mutations=True)
-merge_getitem_cat_pass = PatternMatcherPass(prevent_match_across_mutations=True)
-predispatch_pass = PatternMatcherPass(prevent_match_across_mutations=True)
+normalization_pass = PatternMatcherPass(
+    prevent_match_across_mutations=True, pass_name="normalization_pass"
+)
+merge_splits_pass = PatternMatcherPass(
+    prevent_match_across_mutations=True, pass_name="merge_splits_pass"
+)
+split_cat_pass = PatternMatcherPass(
+    prevent_match_across_mutations=True, pass_name="split_cat_pass"
+)
+unbind_stack_pass = PatternMatcherPass(
+    prevent_match_across_mutations=True, pass_name="unbind_stack_pass"
+)
+efficient_conv_bn_eval_pass = PatternMatcherPass(
+    prevent_match_across_mutations=True, pass_name="efficient_conv_bn_eval_pass"
+)
+merge_getitem_cat_pass = PatternMatcherPass(
+    prevent_match_across_mutations=True, pass_name="merge_getitem_cat_pass"
+)
+predispatch_pass = PatternMatcherPass(
+    prevent_match_across_mutations=True, pass_name="predispatch_pass"
+)
 # based on predispatch aten IR
 normalization_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
 merge_splits_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
@@ -114,17 +129,21 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):
                 f"[Pre grad(predispatch IR)]Apply split_cat, index: {ind}",
             )
     else:
+        # We only log the graph with changes to avoid the excessive compilation time
+        # https://fb.workplace.com/groups/257735836456307/permalink/633533465543207/
         gm = fuse_fx(gm, example_inputs)
         numpy_compat_normalization(gm.graph)
-        print_graph(gm.graph, "Before group batch fusion in pre grad pass.")
+        inductor_before_change = copy.deepcopy(counters["inductor"])
         group_batch_fusion_passes(gm.graph, pre_grad=True)
-        print_graph(gm.graph, "Before split cat in pre grad pass.")
+        if counters["inductor"] != inductor_before_change:
+            optimus_scuba_log["group_batch_fusion_pre_grad"] = upload_graph(gm.graph)
         for pattern_matcher_pass in pattern_matcher_passes:
+            inductor_before_change = copy.deepcopy(counters["inductor"])
             pattern_matcher_pass.apply(gm.graph)  # type: ignore[arg-type]
-            print_graph(
-                gm.graph,
-                "Apply split cat pattern matcher PatternMatcherPass in pre grad.",
-            )
+            if counters["inductor"] != inductor_before_change:
+                optimus_scuba_log[
+                    f"split_cat_pattern_{pattern_matcher_pass.pass_name}_pre_grad"
+                ] = upload_graph(gm.graph)
 
     if config.pre_grad_custom_pass is not None:
         config.pre_grad_custom_pass(gm.graph)
@@ -148,8 +167,6 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):
             config.fx_passes_numeric_check.get("precision", 1e-4),
         )
 
-    print_graph(gm.graph, "After recompile in pre grad pass.")
-
     return gm
 
 
diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py
index 4beedab8c5d45..c48251fc352d6 100644
--- a/torch/_inductor/pattern_matcher.py
+++ b/torch/_inductor/pattern_matcher.py
@@ -1214,12 +1214,15 @@ def compute_mutation_region_ids(graph: torch.fx.GraphModule):
 
 
 class PatternMatcherPass:
-    def __init__(self, prevent_match_across_mutations=False):
+    def __init__(
+        self, prevent_match_across_mutations=False, pass_name: Optional[str] = None
+    ):
         super().__init__()
         self.patterns: DefaultDict[
             torch.fx.node.Target, List[PatternEntry]
         ] = defaultdict(list)
         self.prevent_match_across_mutations = prevent_match_across_mutations
+        self.pass_name = pass_name
 
     def __getitem__(self, item: torch.fx.node.Target) -> List[PatternEntry]:
         return self.patterns[item]
diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py
index 170cb12dfe10c..dd5ecffa50496 100644
--- a/torch/_utils_internal.py
+++ b/torch/_utils_internal.py
@@ -83,7 +83,7 @@ def log_compilation_event(metrics):
     log.info("%s", metrics)
 
 
-def print_graph(graph, msg: str):
+def upload_graph(graph):
     pass
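
For context, here is a minimal standalone sketch (not part of the patch) of how the new hooks can be exercised: it compiles a toy model with the batch_linear pre-grad fusion enabled and then inspects optimus_scuba_log, mirroring the new assertions in test_group_batch_fusion.py. The TinyLinears module and the expectation that batch_linear actually fires on it are assumptions; in the OSS build upload_graph is a no-op stub, so only the presence of the keys is observable.

# Hypothetical example, assuming this patch is applied.
import torch
import torch._inductor.config as inductor_config
from torch._dynamo.utils import counters, optimus_scuba_log


class TinyLinears(torch.nn.Module):
    """Toy module: same-shaped linears that batch_linear may batch together."""

    def __init__(self) -> None:
        super().__init__()
        self.linears = torch.nn.ModuleList(
            [torch.nn.Linear(16, 16) for _ in range(4)]
        )

    def forward(self, x):
        return torch.cat([linear(x) for linear in self.linears], dim=1)


def main() -> None:
    counters.clear()
    with inductor_config.patch(pre_grad_fusion_options={"batch_linear": {}}):
        compiled = torch.compile(TinyLinears())
        compiled(torch.randn(8, 16))

    # pre_grad_passes only records a graph when the pass changed the counters;
    # upload_graph returns None in open-source builds, so check the key only.
    print("group_batch_fusion_pre_grad" in optimus_scuba_log)
    print(dict(counters["inductor"]))


if __name__ == "__main__":
    main()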