
[PT2][Optimus][Observability] Log the optimus graph transformation to the scuba (pytorch#119745)

Summary:
X-link: pytorch/benchmark#2163


Current everstore upload logging may cause excessive compilation time when the model has lots of graph breaks (post: https://fb.workplace.com/groups/257735836456307/permalink/633533465543207/). We therefore log the transformation only when a pass actually changed the graph.
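
The guard is the same in the pre-grad and post-grad passes: snapshot `counters["inductor"]` before running a pass, and call `upload_graph` (recording the result in `optimus_scuba_log`) only if the counters changed. Below is a minimal, self-contained sketch of that pattern; it is not part of the diff — `run_pass_with_logging` is an illustrative helper, and `upload_graph` is stubbed the way the OSS version in `torch/_utils_internal.py` is.
```
import copy
from collections import Counter, defaultdict
from typing import Any, Callable, Dict

# Stand-ins for torch._dynamo.utils.counters / optimus_scuba_log and
# torch._utils_internal.upload_graph (the real ones are used in the diff).
counters: Dict[str, Counter] = defaultdict(Counter)
optimus_scuba_log: Dict[str, Any] = {}


def upload_graph(graph) -> Any:
    # OSS no-op stub, mirroring torch/_utils_internal.py; internal builds override it.
    return None


def run_pass_with_logging(graph, pass_fn: Callable, log_key: str) -> None:
    """Run one graph pass and log the graph only if the pass changed anything."""
    before = copy.deepcopy(counters["inductor"])
    pass_fn(graph)
    # A pass bumps its counter only when it rewrites the graph, so an unchanged
    # counter dict means there is nothing new worth uploading.
    if counters["inductor"] != before:
        optimus_scuba_log[log_key] = upload_graph(graph)
```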

Test Plan:
# unit test
```
buck2 test //caffe2/test/inductor:group_batch_fusion
```
Buck UI: https://www.internalfb.com/buck2/8dc77120-76ea-41bf-83b4-ada88bf1e2c8
Test UI: https://www.internalfb.com/intern/testinfra/testrun/7036874622840948
Network: Up: 321KiB  Down: 33MiB  (reSessionID-5372a58f-8b66-43a9-82b4-ee23aeaa44b9)
Jobs completed: 20. Time elapsed: 4:13.4s.
Cache hits: 0%. Commands: 2 (cached: 0, remote: 0, local: 2)
Tests finished: Pass 7. Fail 0. Fatal 0. Skip 0. Build failure 0

```
buck2 test //caffe2/test/inductor:split_cat_fx_passes
```
Buck UI: https://www.internalfb.com/buck2/abc8b8f8-d240-47d3-ad9d-cae13b8e62d3
Test UI: https://www.internalfb.com/intern/testinfra/testrun/13229323926967362
Network: Up: 119KiB  Down: 66KiB  (reSessionID-8bdd34b3-159b-469f-9f00-8384620c13ea)
Jobs completed: 28. Time elapsed: 3:02.9s.
Cache hits: 0%. Commands: 2 (cached: 0, remote: 0, local: 2)
Tests finished: Pass 11. Fail 0. Fatal 0. Skip 0. Build failure 0



# e2e
baseline:
f528209775
proposal:
f531285723

scuba: https://fburl.com/scuba/workflow_signpost/7hamzr64

 {F1456548774}

Reviewed By: jackiexu1992

Differential Revision: D53692344
mengluy0125 authored and facebook-github-bot committed Feb 15, 2024
1 parent 4730022 commit ecc6650
Showing 10 changed files with 92 additions and 54 deletions.
6 changes: 5 additions & 1 deletion test/inductor/test_group_batch_fusion.py
@@ -6,7 +6,7 @@
import torch
import torch._inductor
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import counters
from torch._dynamo.utils import counters, optimus_scuba_log
from torch.testing._internal.inductor_utils import HAS_CUDA

try:
@@ -285,6 +285,7 @@ def test_group_linear_fusion(self):
counters["inductor"]["batch_fusion"],
0,
)
self.assertNotIn("group_batch_fusion_pre_grad", optimus_scuba_log)
ref.sum().backward()
res.sum().backward()
self.compare_parameters(module, traced)
@@ -297,6 +298,7 @@ def test_group_linear_fusion(self):
counters["inductor"]["batch_fusion"],
3,
)
self.assertIn("group_batch_fusion_post_grad", optimus_scuba_log)
counters.clear()

@unittest.skipIf(not has_fbgemm, "requires fbgemm")
@@ -468,6 +470,8 @@ def test_batch_linear_post_grad_fusion(self):
counters["inductor"]["batch_fusion"],
2,
)
self.assertNotIn("group_batch_fusion_pre_grad", optimus_scuba_log)
self.assertIn("group_batch_fusion_post_grad", optimus_scuba_log)


class TestFindIndependentSubsetGreedy(TestCase):
13 changes: 12 additions & 1 deletion test/inductor/test_split_cat_fx_passes.py
@@ -2,7 +2,7 @@

import torch
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import counters
from torch._dynamo.utils import counters, optimus_scuba_log
from torch._inductor.fx_passes.misc_patterns import numpy_compat_normalization
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA
@@ -90,6 +90,10 @@ def cm_with_list(x):
counters["inductor"]["split_cat_norm"],
expected_split_norm_count,
)
if expected_split_norm_count > 0:
self.assertIn(
"split_cat_pattern_normalization_pass_pre_grad", optimus_scuba_log
)
counters.clear()

@patch
@@ -251,6 +255,10 @@ def split_getitem_out_of_order(x):
counters["inductor"]["consecutive_split_merged"],
expected_split_merged,
)
if expected_split_merged > 0:
self.assertIn(
"split_cat_pattern_merge_splits_pass_pre_grad", optimus_scuba_log
)
counters.clear()

@patch
@@ -1063,6 +1071,9 @@ def stack_tahn_unbind(x):
counters["inductor"]["stack_tahn_unbind_merged"],
expected_stack_tahn_unbind_merged,
)
self.assertIn(
"split_cat_pattern_merge_getitem_cat_pass_pre_grad", optimus_scuba_log
)
counters.clear()

def test_numpy_compat_normalization(self):
6 changes: 2 additions & 4 deletions torch/_dynamo/utils.py
@@ -101,6 +101,7 @@


counters: DefaultDict[str, Counter[str]] = collections.defaultdict(collections.Counter)
optimus_scuba_log: Dict[str, Any] = {}
troubleshooting_url = "https://pytorch.org/docs/master/compile/troubleshooting.html"
nnmodule_doc_url = "https://pytorch.org/docs/master/compile/nn-module.html"
nnmodule_doc_url_msg = f"See {nnmodule_doc_url} for more information and limitations."
@@ -1154,10 +1155,7 @@ def dict_keys_repr(const_keys, *, local) -> str:
GLOBAL_KEY_PREFIX = "__dict_key"


from torch._subclasses import ( # noqa: F401
FakeTensorMode,
UnsupportedFakeTensorException,
)
from torch._subclasses import UnsupportedFakeTensorException # noqa: F401


def wrap_fake_exception(fn):
24 changes: 17 additions & 7 deletions torch/_inductor/compile_fx.py
@@ -33,13 +33,19 @@
logging as dynamo_logging,
utils as dynamo_utils,
)
from torch._dynamo.utils import counters, detect_fake_mode, lazy_format_graph_code
from torch._dynamo.utils import (
counters,
detect_fake_mode,
lazy_format_graph_code,
optimus_scuba_log,
)
from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
from torch._inductor.codecache import code_hash, CompiledFxGraph, FxGraphCache

from torch._inductor.debug import save_args_for_compile_fx_inner
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch._utils_internal import signpost_event
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

from .._dynamo.backends.common import aot_autograd
@@ -621,9 +627,11 @@ def fx_codegen_and_compile(
post_grad_passes(gm, is_inference=is_inference)
V.debug.fx_graph_transformed(gm, example_inputs)
post_grad_graphs_log.debug("%s", lazy_format_graph_code("AFTER POST GRAD", gm))
log.debug(
"counters of inductor dict after apply passes on the input FX graph in the post grad pass: %s",
counters["inductor"],
optimus_scuba_log["inductor_post_grad"] = counters["inductor"]
signpost_event(
"optimus",
"compile_fx.post_grad_passes",
optimus_scuba_log,
)

with V.set_fake_mode(fake_mode):
@@ -1159,9 +1167,11 @@ def compile_fx(
)

model_ = pre_grad_passes(model_, example_inputs_)
log.debug(
"counters of inductor dict after apply passes on the input FX graph in the pre grad pass: %s",
counters["inductor"],
optimus_scuba_log["inductor_pre_grad"] = counters["inductor"]
signpost_event(
"optimus",
"compile_fx.pre_grad_passes",
optimus_scuba_log,
)

if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_):
3 changes: 0 additions & 3 deletions torch/_inductor/fx_passes/group_batch_fusion.py
@@ -17,7 +17,6 @@

import torch
from torch._dynamo.utils import counters
from torch._utils_internal import print_graph

from .. import config
from ..pattern_matcher import (
@@ -936,7 +935,6 @@ def generate_fusion_from_config(config_options: Dict[str, Any], pre_grad=True):


def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):
print_graph(graph, "Before group_batch fusion in pre grad pass.")
fusions: List[GroupBatchFusionBase] = []
# we keep all current pre grad fusions to keep
# current implementation, will remove this later
@@ -965,4 +963,3 @@

for rule in fusions:
apply_group_batch_fusion(graph, rule) # type: ignore[arg-type]
print_graph(graph, f"Apply fusion {rule.__class__.__name__}.")
17 changes: 9 additions & 8 deletions torch/_inductor/fx_passes/numeric_utils.py
@@ -8,7 +8,6 @@

import torch
import torch.optim as optim
from torch._utils_internal import print_graph

from .. import config

@@ -43,8 +42,8 @@ def clean_memory() -> None:
def compare_dict_tensors(dict_base, dict_control, precision):
if len(set(dict_base.keys())) != len(set(dict_control.keys())):
logger.warning("Mismatch keys found before and after pre/post grad fx passes.")
print_graph(dict_base.keys(), "keys before pre/post grad fx passes.")
print_graph(dict_control.keys(), "keys after pre/post grad fx passes.")
logger.debug("keys before pre/post grad fx passes %s", dict_base.keys())
logger.debug("keys after pre/post grad fx passes %s", dict_control.keys())
return False
is_allclose = True
for key in dict_base.keys():
@@ -66,8 +65,8 @@ def compare_dict_tensors(dict_base, dict_control, precision):
logger.warning(
"Mismatch parameter values found before and after pre/post grad fx passes."
)
print_graph(dict_base[key], "value before pre/post grad fx passes.")
print_graph(dict_control[key], "value after pre/post grad fx passes.")
logger.debug("value before pre/post grad fx passes %s", dict_base[key])
logger.debug("value after pre/post grad fx passes %s", dict_control[key])
is_allclose = False
return is_allclose

@@ -92,9 +91,11 @@ def compare_tuple_tensors(tuple_base, tuple_control, precision):
atol=precision,
equal_nan=True,
):
print_graph(tuple_base[i], "forward output before pre/post grad fx passes.")
print_graph(
tuple_control[i], "forward output after pre/post grad fx passes."
logger.debug(
"forward output before pre/post grad fx passes %s", tuple_base[i]
)
logger.debug(
"forward output after pre/post grad fx passes %s", tuple_control[i]
)
is_allclose = False
return is_allclose
17 changes: 6 additions & 11 deletions torch/_inductor/fx_passes/post_grad.py
@@ -1,3 +1,4 @@
import copy
import functools
import itertools
import logging
@@ -12,10 +13,11 @@
import torch.utils._pytree as pytree
from torch import fx
from torch._decomp import register_decomposition
from torch._dynamo.utils import counters, optimus_scuba_log

from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype

from torch._utils_internal import print_graph
from torch._utils_internal import upload_graph
from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq

from .. import config, ir, pattern_matcher
@@ -80,18 +82,13 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):

if config.pattern_matcher:
lazy_init()

print_graph(gm.graph, "Before group batch fusion in post grad pass.")
inductor_before_change = copy.deepcopy(counters["inductor"])
group_batch_fusion_passes(gm.graph, pre_grad=False)
print_graph(gm.graph, "After group batch fusion in post grad pass.")
if counters["inductor"] != inductor_before_change:
optimus_scuba_log["group_batch_fusion_post_grad"] = upload_graph(gm.graph)
remove_noop_ops(gm.graph)
print_graph(gm.graph, "Before split cat in post grad pass.")
for patterns in pass_patterns:
patterns.apply(gm.graph) # type: ignore[arg-type]
print_graph(
gm.graph,
"Apply split cat pattern matcher PatternMatcherPass in post grad.",
)
if is_inference:
inference_patterns.apply(gm.graph) # type: ignore[arg-type]

@@ -112,8 +109,6 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
gm.recompile()
gm.graph.lint()

print_graph(gm.graph, "After recompile in post grad pass.")


@init_once_fakemode
def lazy_init():
53 changes: 36 additions & 17 deletions torch/_inductor/fx_passes/pre_grad.py
@@ -4,15 +4,16 @@

import torch
import torch.nn as nn
from torch._dynamo.utils import detect_fake_mode
from torch._utils_internal import print_graph
from torch._dynamo.utils import counters, detect_fake_mode, optimus_scuba_log
from torch._utils_internal import upload_graph
from torch.fx.experimental.optimization import (
matches_module_pattern,
replace_node_module,
)
from torch.fx.passes.shape_prop import ShapeProp
from torch.nn import functional as F
from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights

from .. import config

from ..fx_utils import matches_module_function_pattern
@@ -27,13 +28,27 @@

log = logging.getLogger(__name__)

normalization_pass = PatternMatcherPass(prevent_match_across_mutations=True)
merge_splits_pass = PatternMatcherPass(prevent_match_across_mutations=True)
split_cat_pass = PatternMatcherPass(prevent_match_across_mutations=True)
unbind_stack_pass = PatternMatcherPass(prevent_match_across_mutations=True)
efficient_conv_bn_eval_pass = PatternMatcherPass(prevent_match_across_mutations=True)
merge_getitem_cat_pass = PatternMatcherPass(prevent_match_across_mutations=True)
predispatch_pass = PatternMatcherPass(prevent_match_across_mutations=True)
normalization_pass = PatternMatcherPass(
prevent_match_across_mutations=True, pass_name="normalization_pass"
)
merge_splits_pass = PatternMatcherPass(
prevent_match_across_mutations=True, pass_name="merge_splits_pass"
)
split_cat_pass = PatternMatcherPass(
prevent_match_across_mutations=True, pass_name="split_cat_pass"
)
unbind_stack_pass = PatternMatcherPass(
prevent_match_across_mutations=True, pass_name="unbind_stack_pass"
)
efficient_conv_bn_eval_pass = PatternMatcherPass(
prevent_match_across_mutations=True, pass_name="efficient_conv_bn_eval_pass"
)
merge_getitem_cat_pass = PatternMatcherPass(
prevent_match_across_mutations=True, pass_name="merge_getitem_cat_pass"
)
predispatch_pass = PatternMatcherPass(
prevent_match_across_mutations=True, pass_name="predispatch_pass"
)
# based on predispatch aten IR
normalization_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
merge_splits_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
@@ -114,17 +129,23 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):
f"[Pre grad(predispatch IR)]Apply split_cat, index: {ind}",
)
else:
# We only log the graph with changes to avoid the excessive compilation time
# https://fb.workplace.com/groups/257735836456307/permalink/633533465543207/
gm = fuse_fx(gm, example_inputs)
numpy_compat_normalization(gm.graph)
print_graph(gm.graph, "Before group batch fusion in pre grad pass.")
inductor_before_change = copy.deepcopy(counters["inductor"])
group_batch_fusion_passes(gm.graph, pre_grad=True)
print_graph(gm.graph, "Before split cat in pre grad pass.")
if counters["inductor"] != inductor_before_change:
optimus_scuba_log["group_batch_fusion_pre_grad"] = upload_graph(
gm.graph
)
for pattern_matcher_pass in pattern_matcher_passes:
inductor_before_change = copy.deepcopy(counters["inductor"])
pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type]
print_graph(
gm.graph,
"Apply split cat pattern matcher PatternMatcherPass in pre grad.",
)
if counters["inductor"] != inductor_before_change:
optimus_scuba_log[
f"split_cat_pattern_{pattern_matcher_pass.pass_name}_pre_grad"
] = upload_graph(gm.graph)

if config.pre_grad_custom_pass is not None:
config.pre_grad_custom_pass(gm.graph)
@@ -148,8 +169,6 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):
config.fx_passes_numeric_check.get("precision", 1e-4),
)

print_graph(gm.graph, "After recompile in pre grad pass.")

return gm


5 changes: 4 additions & 1 deletion torch/_inductor/pattern_matcher.py
@@ -1214,12 +1214,15 @@ def compute_mutation_region_ids(graph: torch.fx.GraphModule):


class PatternMatcherPass:
def __init__(self, prevent_match_across_mutations=False):
def __init__(
self, prevent_match_across_mutations=False, pass_name: Optional[str] = None
):
super().__init__()
self.patterns: DefaultDict[
torch.fx.node.Target, List[PatternEntry]
] = defaultdict(list)
self.prevent_match_across_mutations = prevent_match_across_mutations
self.pass_name = pass_name

def __getitem__(self, item: torch.fx.node.Target) -> List[PatternEntry]:
return self.patterns[item]
2 changes: 1 addition & 1 deletion torch/_utils_internal.py
@@ -83,7 +83,7 @@ def log_compilation_event(metrics):
log.info("%s", metrics)


def print_graph(graph, msg: str):
def upload_graph(graph):
pass


