[PT2][Optimus][Observability] Log the optimus graph transformation to the scuba (pytorch#119745)

Summary:
X-link: pytorch/benchmark#2163


The current everstore upload logging can cause excessive compilation time when a model has many graph breaks (post: https://fb.workplace.com/groups/257735836456307/permalink/633533465543207/). With this change, we log the graph transformation only when a pass actually changed the graph.
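
The mechanics are visible in the torch/_inductor/fx_passes/post_grad.py hunk below: snapshot the inductor counters, run the fusion passes, and upload the graph only if a counter moved. A minimal sketch of that guard (the helper name run_post_grad_fusions is ours, for illustration):
```
import copy

from torch._dynamo.utils import counters, optimus_scuba_log
from torch._inductor.fx_passes.group_batch_fusion import group_batch_fusion_passes
from torch._utils_internal import upload_graph


def run_post_grad_fusions(gm):
    # Snapshot the inductor counters before the fusion passes run.
    inductor_before_change = copy.deepcopy(counters["inductor"])
    group_batch_fusion_passes(gm.graph, pre_grad=False)
    # Upload the transformed graph only when a fusion actually fired
    # (i.e. the counters changed), instead of on every pass.
    if counters["inductor"] != inductor_before_change:
        optimus_scuba_log["group_batch_fusion_post_grad"] = upload_graph(gm.graph)
```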

Test Plan:
# unit tests
```
buck2 test //caffe2/test/inductor:group_batch_fusion
```
Buck UI: https://www.internalfb.com/buck2/8dc77120-76ea-41bf-83b4-ada88bf1e2c8
Test UI: https://www.internalfb.com/intern/testinfra/testrun/7036874622840948
Network: Up: 321KiB  Down: 33MiB  (reSessionID-5372a58f-8b66-43a9-82b4-ee23aeaa44b9)
Jobs completed: 20. Time elapsed: 4:13.4s.
Cache hits: 0%. Commands: 2 (cached: 0, remote: 0, local: 2)
Tests finished: Pass 7. Fail 0. Fatal 0. Skip 0. Build failure 0

```
buck2 test //caffe2/test/inductor:split_cat_fx_passes
```
Buck UI: https://www.internalfb.com/buck2/abc8b8f8-d240-47d3-ad9d-cae13b8e62d3
Test UI: https://www.internalfb.com/intern/testinfra/testrun/13229323926967362
Network: Up: 119KiB  Down: 66KiB  (reSessionID-8bdd34b3-159b-469f-9f00-8384620c13ea)
Jobs completed: 28. Time elapsed: 3:02.9s.
Cache hits: 0%. Commands: 2 (cached: 0, remote: 0, local: 2)
Tests finished: Pass 11. Fail 0. Fatal 0. Skip 0. Build failure 0



# e2e
baseline:
f528209775
proposal:
f531285723

scuba: https://fburl.com/scuba/workflow_signpost/7hamzr64

 {F1456548774}

Reviewed By: jackiexu1992

Differential Revision: D53692344
mengluy0125 authored and facebook-github-bot committed Feb 15, 2024
1 parent 0898ead commit dbd9ac6
Showing 10 changed files with 104 additions and 76 deletions.
42 changes: 20 additions & 22 deletions test/inductor/test_group_batch_fusion.py
@@ -6,8 +6,7 @@
import torch
import torch._inductor
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import counters
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch._dynamo.utils import counters, optimus_scuba_log

try:
# importing this will register fbgemm lowerings for inductor
@@ -18,11 +17,8 @@
has_fbgemm = False
pass

requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")


class MyModule(torch.nn.Module):
def __init__(self, z: int, has_bias: bool, device="cuda") -> None:
def __init__(self, z: int, has_bias: bool, device="cpu") -> None:
super().__init__()
self.z = z
self.device = device
@@ -220,7 +216,6 @@ def forward(self, x):
return torch.cat(div, dim=1)


@requires_cuda
@torch._inductor.config.patch(
pre_grad_fusion_options={
"batch_linear": {},
@@ -271,8 +266,8 @@ def test_group_linear_fusion(self):
z = 10
for has_bias in [True, False]:
counters.clear()
module = MyModule(z, has_bias).to("cuda")
input = [torch.randn(z, z, device="cuda")]
module = MyModule(z, has_bias).to("cpu")
input = [torch.randn(z, z, device="cpu")]
traced = torch.compile(module)
ref = module(*input)
res = traced(*input)
@@ -285,6 +280,7 @@ def test_group_linear_fusion(self):
counters["inductor"]["batch_fusion"],
0,
)
self.assertNotIn("group_batch_fusion_pre_grad", optimus_scuba_log)
ref.sum().backward()
res.sum().backward()
self.compare_parameters(module, traced)
@@ -297,13 +293,14 @@ def test_group_linear_fusion(self):
counters["inductor"]["batch_fusion"],
3,
)
self.assertIn("group_batch_fusion_post_grad", optimus_scuba_log)
counters.clear()

@unittest.skipIf(not has_fbgemm, "requires fbgemm")
def test_group_linear_fusion_different_shapes(self):
counters.clear()
module = MyModule2().eval().to("cuda")
input = [torch.rand(4, 24, device="cuda")]
module = MyModule2().eval().to("cpu")
input = [torch.rand(4, 24, device="cpu")]
traced = torch.compile(module)
ref = module(*input)
res = traced(*input)
@@ -334,8 +331,8 @@ def test_batch_layer_norm_fusion(self):
for has_weight in [True, False]:
for has_bias in [True, False]:
counters.clear()
module = MyModule3("cuda", has_weight, has_bias).to("cuda")
input = [torch.randn(2, 5, 50, device="cuda")]
module = MyModule3("cpu", has_weight, has_bias).to("cpu")
input = [torch.randn(2, 5, 50, device="cpu")]
traced = torch.compile(module)
ref = module(*input)
res = traced(*input)
@@ -363,8 +360,8 @@ def test_batch_linear_lhs_fusion(self):
z = 10
for has_bias in [True, False]:
counters.clear()
module = MyModule4(z, "cuda", has_bias)
input = [torch.randn(20, z, device="cuda")]
module = MyModule4(z, "cpu", has_bias)
input = [torch.randn(20, z, device="cpu")]
traced = torch.compile(module)
ref = module(*input)
res = traced(*input)
@@ -387,8 +384,8 @@ def test_batch_linear_lhs_fusion(self):
def test_batch_linear_pre_grad_fusion(self):
for has_bias in [True, False]:
counters.clear()
module = MyModule5("cuda", has_bias)
input = [torch.randn(50, 500, device="cuda")]
module = MyModule5("cpu", has_bias)
input = [torch.randn(50, 500, device="cpu")]
traced = torch.compile(module)
ref = module(*input)
res = traced(*input)
@@ -411,8 +408,8 @@ def test_batch_linear_pre_grad_fusion(self):

def test_pointwise_op_fusion(self):
counters.clear()
module = TestPoitwiseOps("cuda")
input = [torch.randn(50, 1000, requires_grad=True, device="cuda")]
module = TestPoitwiseOps("cpu")
input = [torch.randn(50, 1000, requires_grad=True, device="cpu")]
traced = torch.compile(module)
ref = module(*input)
res = traced(*input)
@@ -450,16 +447,15 @@ def forward(self, inputs):
return output


@requires_cuda
@torch._inductor.config.patch(
post_grad_fusion_options={"batch_linear_post_grad": {"require_fbgemm": False}}
)
class TestPostGradBatchLinearFusion(TestCase):
def test_batch_linear_post_grad_fusion(self):
pt1_module = TestBMMFusionModule().cuda()
pt1_module = TestBMMFusionModule()
inputs = []
for _ in range(10):
inputs.append(torch.randn(10, 10).cuda())
inputs.append(torch.randn(10, 10))
eager_output = pt1_module(inputs)
pt2_module = torch.compile(pt1_module)
pt2_output = pt2_module(inputs)
@@ -468,6 +464,8 @@ def test_batch_linear_post_grad_fusion(self):
counters["inductor"]["batch_fusion"],
2,
)
self.assertNotIn("group_batch_fusion_pre_grad", optimus_scuba_log)
self.assertIn("group_batch_fusion_post_grad", optimus_scuba_log)


class TestFindIndependentSubsetGreedy(TestCase):
13 changes: 11 additions & 2 deletions test/inductor/test_split_cat_fx_passes.py
@@ -2,7 +2,7 @@

import torch
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import counters
from torch._dynamo.utils import counters, optimus_scuba_log
from torch._inductor.fx_passes.misc_patterns import numpy_compat_normalization
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA
@@ -90,6 +90,10 @@ def cm_with_list(x):
counters["inductor"]["split_cat_norm"],
expected_split_norm_count,
)
if expected_split_norm_count > 0:
self.assertIn(
"split_cat_pattern_normalization_pass_pre_grad", optimus_scuba_log
)
counters.clear()

@patch
@@ -251,6 +255,10 @@ def split_getitem_out_of_order(x):
counters["inductor"]["consecutive_split_merged"],
expected_split_merged,
)
if expected_split_merged > 0:
self.assertIn(
"split_cat_pattern_merge_splits_pass_pre_grad", optimus_scuba_log
)
counters.clear()

@patch
@@ -1063,6 +1071,7 @@ def stack_tahn_unbind(x):
counters["inductor"]["stack_tahn_unbind_merged"],
expected_stack_tahn_unbind_merged,
)
self.assertIn("split_cat_pattern_merge_getitem_cat_pass_pre_grad", optimus_scuba_log)
counters.clear()

def test_numpy_compat_normalization(self):
@@ -1087,7 +1096,7 @@ def test_stack_normalization_axis_kwarg(self):
def fn(x, y):
return torch.stack([x, y], axis=1)

x, y = (torch.rand((4, 4), device="cuda") for _ in range(2))
x, y = (torch.rand((4, 4), device="cpu") for _ in range(2))
expected = fn(x, y)
actual = torch.compile(fn)(x, y)

6 changes: 2 additions & 4 deletions torch/_dynamo/utils.py
@@ -101,6 +101,7 @@


counters: DefaultDict[str, Counter[str]] = collections.defaultdict(collections.Counter)
optimus_scuba_log: Dict[str, Any] = {}
troubleshooting_url = "https://pytorch.org/docs/master/compile/troubleshooting.html"
nnmodule_doc_url = "https://pytorch.org/docs/master/compile/nn-module.html"
nnmodule_doc_url_msg = f"See {nnmodule_doc_url} for more information and limitations."
@@ -1154,10 +1155,7 @@ def dict_keys_repr(const_keys, *, local) -> str:
GLOBAL_KEY_PREFIX = "__dict_key"


from torch._subclasses import ( # noqa: F401
FakeTensorMode,
UnsupportedFakeTensorException,
)
from torch._subclasses import UnsupportedFakeTensorException # noqa: F401


def wrap_fake_exception(fn):
24 changes: 17 additions & 7 deletions torch/_inductor/compile_fx.py
@@ -33,13 +33,19 @@
logging as dynamo_logging,
utils as dynamo_utils,
)
from torch._dynamo.utils import counters, detect_fake_mode, lazy_format_graph_code
from torch._dynamo.utils import (
counters,
detect_fake_mode,
lazy_format_graph_code,
optimus_scuba_log,
)
from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
from torch._inductor.codecache import code_hash, CompiledFxGraph, FxGraphCache

from torch._inductor.debug import save_args_for_compile_fx_inner
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch._utils_internal import signpost_event
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

from .._dynamo.backends.common import aot_autograd
@@ -621,9 +627,11 @@ def fx_codegen_and_compile(
post_grad_passes(gm, is_inference=is_inference)
V.debug.fx_graph_transformed(gm, example_inputs)
post_grad_graphs_log.debug("%s", lazy_format_graph_code("AFTER POST GRAD", gm))
log.debug(
"counters of inductor dict after apply passes on the input FX graph in the post grad pass: %s",
counters["inductor"],
optimus_scuba_log["inductor_post_grad"] = counters["inductor"]
signpost_event(
"optimus",
"compile_fx.post_grad_passes",
optimus_scuba_log,
)

with V.set_fake_mode(fake_mode):
@@ -1159,9 +1167,11 @@ def compile_fx(
)

model_ = pre_grad_passes(model_, example_inputs_)
log.debug(
"counters of inductor dict after apply passes on the input FX graph in the pre grad pass: %s",
counters["inductor"],
optimus_scuba_log["inductor_pre_grad"] = counters["inductor"]
signpost_event(
"optimus",
"compile_fx.pre_grad_passes",
optimus_scuba_log,
)

if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_):
3 changes: 0 additions & 3 deletions torch/_inductor/fx_passes/group_batch_fusion.py
@@ -17,7 +17,6 @@

import torch
from torch._dynamo.utils import counters
from torch._utils_internal import print_graph

from .. import config
from ..pattern_matcher import (
@@ -936,7 +935,6 @@ def generate_fusion_from_config(config_options: Dict[str, Any], pre_grad=True):


def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):
print_graph(graph, "Before group_batch fusion in pre grad pass.")
fusions: List[GroupBatchFusionBase] = []
# we keep all current pre grad fusions to keep
# current implementation, will remove this later
@@ -965,4 +963,3 @@ def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):

for rule in fusions:
apply_group_batch_fusion(graph, rule) # type: ignore[arg-type]
print_graph(graph, f"Apply fusion {rule.__class__.__name__}.")
17 changes: 9 additions & 8 deletions torch/_inductor/fx_passes/numeric_utils.py
@@ -8,7 +8,6 @@

import torch
import torch.optim as optim
from torch._utils_internal import print_graph

from .. import config

@@ -43,8 +42,8 @@ def clean_memory() -> None:
def compare_dict_tensors(dict_base, dict_control, precision):
if len(set(dict_base.keys())) != len(set(dict_control.keys())):
logger.warning("Mismatch keys found before and after pre/post grad fx passes.")
print_graph(dict_base.keys(), "keys before pre/post grad fx passes.")
print_graph(dict_control.keys(), "keys after pre/post grad fx passes.")
logger.debug("keys before pre/post grad fx passes %s", dict_base.keys())
logger.debug("keys after pre/post grad fx passes %s", dict_control.keys())
return False
is_allclose = True
for key in dict_base.keys():
@@ -66,8 +65,8 @@ def compare_dict_tensors(dict_base, dict_control, precision):
logger.warning(
"Mismatch parameter values found before and after pre/post grad fx passes."
)
print_graph(dict_base[key], "value before pre/post grad fx passes.")
print_graph(dict_control[key], "value after pre/post grad fx passes.")
logger.debug("value before pre/post grad fx passes %s", dict_base[key])
logger.debug("value after pre/post grad fx passes %s", dict_control[key])
is_allclose = False
return is_allclose

@@ -92,9 +91,11 @@ def compare_tuple_tensors(tuple_base, tuple_control, precision):
atol=precision,
equal_nan=True,
):
print_graph(tuple_base[i], "forward output before pre/post grad fx passes.")
print_graph(
tuple_control[i], "forward output after pre/post grad fx passes."
logger.debug(
"forward output before pre/post grad fx passes %s", tuple_base[i]
)
logger.debug(
"forward output after pre/post grad fx passes %s", tuple_control[i]
)
is_allclose = False
return is_allclose
17 changes: 6 additions & 11 deletions torch/_inductor/fx_passes/post_grad.py
@@ -1,3 +1,4 @@
import copy
import functools
import itertools
import logging
@@ -12,10 +13,11 @@
import torch.utils._pytree as pytree
from torch import fx
from torch._decomp import register_decomposition
from torch._dynamo.utils import counters, optimus_scuba_log

from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype

from torch._utils_internal import print_graph
from torch._utils_internal import upload_graph
from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq

from .. import config, ir, pattern_matcher
@@ -80,18 +82,13 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):

if config.pattern_matcher:
lazy_init()

print_graph(gm.graph, "Before group batch fusion in post grad pass.")
inductor_before_change = copy.deepcopy(counters["inductor"])
group_batch_fusion_passes(gm.graph, pre_grad=False)
print_graph(gm.graph, "After group batch fusion in post grad pass.")
if counters["inductor"] != inductor_before_change:
optimus_scuba_log["group_batch_fusion_post_grad"] = upload_graph(gm.graph)
remove_noop_ops(gm.graph)
print_graph(gm.graph, "Before split cat in post grad pass.")
for patterns in pass_patterns:
patterns.apply(gm.graph) # type: ignore[arg-type]
print_graph(
gm.graph,
"Apply split cat pattern matcher PatternMatcherPass in post grad.",
)
if is_inference:
inference_patterns.apply(gm.graph) # type: ignore[arg-type]

@@ -112,8 +109,6 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
gm.recompile()
gm.graph.lint()

print_graph(gm.graph, "After recompile in post grad pass.")


@init_once_fakemode
def lazy_init():
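
For reference, optimus_scuba_log is a plain module-level dict keyed by pass name, populated only for passes that actually changed the graph; that is the contract the updated tests assert. A rough end-to-end sketch (the TwoLinears toy module is ours, and whether the batch-linear fusion really fires on it depends on the pass heuristics):
```
import torch
import torch._inductor
from torch._dynamo.utils import optimus_scuba_log


class TwoLinears(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(10, 10)
        self.b = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.a(x) + self.b(x)


with torch._inductor.config.patch(
    post_grad_fusion_options={"batch_linear_post_grad": {"require_fbgemm": False}}
):
    compiled = torch.compile(TwoLinears())
    compiled(torch.randn(4, 10))

# Present only if the post-grad fusion actually rewrote the graph.
print("group_batch_fusion_post_grad" in optimus_scuba_log)
```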