Remove path that is not needed for SDPA decomp
Summary:
Since we changed the flow to run decomps before the quantizer, we don't need the extra SDPA path now (which was one of the points!). Remove it here.
`ReplaceSafeSoftmaxWithSoftmax` is retained since it's used by Helios for now.
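
As an illustrative aside (not part of this commit): with decomps running first, the default export decompositions already lower SDPA before the quantizer ever sees the graph, which is why the explicit branch is dead. A minimal sketch, using a hypothetical toy module:

import torch

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(1, 4, 8, 16)
ep = torch.export.export(Attn(), (q, k, v))
# The default decomposition table replaces SDPA with its constituent
# ops (matmul, softmax, ...), so no explicit SDPA pass is needed.
ep = ep.run_decompositions()
assert not any(
    node.op == "call_function"
    and node.target == torch.ops.aten.scaled_dot_product_attention.default
    for node in ep.graph_module.graph.nodes
)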

Differential Revision: D67561688
mcremon-meta authored and facebook-github-bot committed Dec 27, 2024
1 parent 3aebf2f commit f984fe2
Showing 2 changed files with 0 additions and 24 deletions.
16 changes: 0 additions & 16 deletions backends/cadence/aot/compiler.py
@@ -18,17 +18,11 @@
 )
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
 from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
-
-from executorch.backends.cadence.aot.replace_ops import ReplaceSafeSoftmaxWithSoftmax
 from executorch.backends.cadence.aot.utils import (
     get_default_memory_config,
     MemoryConfig,
-    model_gm_has_SDPA,
     model_is_quantized,
 )
-from executorch.backends.transforms.decompose_sdpa import (
-    DecomposeScaledDotProductAttention,
-)
 from executorch.devtools import generate_etrecord
 from executorch.exir import (
     EdgeCompileConfig,
@@ -91,16 +85,6 @@ def convert_pt2(
         .module()
     )
 
-    if model_gm_has_SDPA(model_gm):
-        # Decompose SDPA
-        DecomposeScaledDotProductAttention(False)(model_gm)
-
-        # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882
-        # for details).
-        result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)
-        assert result is not None
-        model_gm = result.graph_module
-
     # Prepare
     prepared_model = prepare_pt2e(model_gm, quantizer)
 
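With the branch removed, convert_pt2 goes straight from the exported module to prepare_pt2e. A simplified sketch of the surviving flow (paraphrased, not the exact file contents; model, inputs, and quantizer are assumed to be in scope):

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

model_gm = torch.export.export_for_training(model, inputs).module()
prepared_model = prepare_pt2e(model_gm, quantizer)  # quantizer: a CadenceQuantizer
prepared_model(*inputs)  # calibration run
converted_model = convert_pt2e(prepared_model)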
8 changes: 0 additions & 8 deletions backends/cadence/aot/utils.py
@@ -235,14 +235,6 @@ def print_ops_info(
     )
 
 
-def model_gm_has_SDPA(model_gm: torch.fx.GraphModule) -> bool:
-    for node in model_gm.graph.nodes:
-        if node.op == "call_function":
-            if node.target == torch.ops.aten.scaled_dot_product_attention.default:
-                return True
-    return False
-
-
 def save_pte_program(
     prog: ExecutorchProgramManager, model_name: str, output_dir: str = ""
 ) -> None:
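Out-of-tree callers that relied on the deleted helper can inline the same node scan. A generic sketch with a hypothetical name:

def graph_has_op(gm: torch.fx.GraphModule, op) -> bool:
    # True if any call_function node in the graph targets the given op.
    return any(
        node.op == "call_function" and node.target == op
        for node in gm.graph.nodes
    )

# e.g. graph_has_op(model_gm, torch.ops.aten.scaled_dot_product_attention.default)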
