List decompositions for torch.export #26878

Open · wants to merge 1 commit into base: master
@@ -4,14 +4,14 @@
 # flake8: noqa
 # mypy: ignore-errors
 
+import logging
+import torch
+
 from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder
 from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
-from openvino.runtime import op, PartialShape, Type as OVType, OVAny, Shape
+from openvino.runtime import PartialShape, Type as OVType, OVAny, Shape
 from openvino.frontend.pytorch.utils import make_constant, fetch_attr, pt_to_ov_type_map, torch_tensor_to_ov_const
 
-import torch
-
-import logging
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.WARNING)


@@ -46,7 +46,9 @@ def convolution_backward(

     return grad_input, grad_weight, grad_bias
 
 
-@register_decomposition(aten._scaled_dot_product_flash_attention.default)
-def scaled_dot_product_flash_attention(
-    query,
+if len(get_decompositions([aten._scaled_dot_product_flash_attention.default])) == 0:
+
+    @register_decomposition(aten._scaled_dot_product_flash_attention.default)
+    def scaled_dot_product_flash_attention(
+        query,
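For context on the new guard: get_decompositions returns the table of registered decompositions for the requested ops, so an empty result means the running torch build ships no decomposition for flash-attention SDPA and the fallback defined here gets registered. A quick probe, as a sketch (the aten alias mirrors the one this file already uses):

```python
import torch
from torch._decomp import get_decompositions

aten = torch.ops.aten  # same alias as in decompositions.py

# An empty table means no built-in decomposition is registered for
# flash-attention SDPA, so the @register_decomposition fallback applies.
table = get_decompositions([aten._scaled_dot_product_flash_attention.default])
print("fallback needed:", len(table) == 0)
```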

@@ -101,16 +103,197 @@ def scaled_dot_product_flash_attention(


 def get_aot_decomposition_list():
-    return ([torch.ops.aten._scaled_dot_product_flash_attention.default,
-             torch.ops.aten._softmax.default,
-             torch.ops.aten._softmax_backward_data.default,
-             torch.ops.aten.convolution_backward.default,
-             torch.ops.aten.gelu_backward.default,
-             torch.ops.aten.native_group_norm.default,
-             torch.ops.aten.native_group_norm_backward.default,
-             torch.ops.aten.native_layer_norm.default,
-             torch.ops.aten.native_layer_norm_backward.default,
-             torch.ops.aten.slice_backward.default])
+    return [
+        torch.ops.aten._scaled_dot_product_flash_attention.default,
+        torch.ops.aten._softmax.default,
+        torch.ops.aten._softmax_backward_data.default,
+        torch.ops.aten.convolution_backward.default,
+        torch.ops.aten.gelu_backward.default,
+        torch.ops.aten.native_group_norm.default,
+        torch.ops.aten.native_group_norm_backward.default,
+        torch.ops.aten.native_layer_norm.default,
+        torch.ops.aten.native_layer_norm_backward.default,
+        torch.ops.aten.slice_backward.default,
+    ]
 
 
 def get_inf_decomposition_list():
-    return ([torch.ops.aten.nll_loss_forward.default])
+    return [torch.ops.aten.nll_loss_forward.default]


Review thread on the new get_export_decomposition_list (shown below):

Contributor: How are we going to maintain this list from one version to the next? Do we really need it for model conversion, or is it only for the NNCF tool?

Contributor Author: This list does not change very often, so we do not really need to keep it always up to date, and we may even shrink it in the future. The main reason the list is needed is that SDPA gets decomposed by default, but we want it to remain a single operation.

Contributor: Can we check in a test that the SDPA op is really created in the resulting graph?

Contributor Author: The decompositions happen inside convert_model. Checking that the operation exists would require patching ovc, and we would also need to decide how to separate TorchScript ops from FX ops when writing such a test. That is outside the scope of this PR; I suggest solving it in a separate PR.
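Building on the thread above, here is a rough sketch of what such a check could look like, kept outside this PR as the author suggests. The toy module, tensor shapes, and test name are invented; it assumes torch >= 2.2 and the get_export_decomposition_list defined below:

```python
import torch
from torch._decomp import get_decompositions
from openvino.frontend.pytorch.torchdynamo.decompositions import get_export_decomposition_list


def test_sdpa_survives_export_decompositions():
    class M(torch.nn.Module):  # toy module, for illustration only
        def forward(self, q, k, v):
            return torch.nn.functional.scaled_dot_product_attention(q, k, v)

    inputs = tuple(torch.rand(1, 4, 8, 16) for _ in range(3))
    ep = torch.export.export(M(), inputs)
    ep = ep.run_decompositions(
        decomp_table=get_decompositions(get_export_decomposition_list()))
    # The curated table deliberately leaves SDPA out, so (if the PR's intent
    # holds) it should survive as a single node rather than being decomposed.
    ops = {n.target for n in ep.module().graph.nodes if n.op == "call_function"}
    sdpa_ops = {
        torch.ops.aten.scaled_dot_product_attention.default,
        torch.ops.aten._scaled_dot_product_flash_attention.default,
    }
    assert ops & sdpa_ops, "SDPA was decomposed"
```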

def get_export_decomposition_list():
    # List of decompositions from torch._decomp.core_aten_decompositions,
    # with _backward ops and ops supported without decomposition removed.
    decomp = [
        torch.ops.aten.addcdiv,
        torch.ops.aten.addcdiv_,
        torch.ops.aten.addcmul,
        torch.ops.aten.addcmul_,
        torch.ops.aten.addr,
        torch.ops.aten.affine_grid_generator,
        torch.ops.aten.all,
        torch.ops.aten.aminmax,
        torch.ops.aten.arange.default,
        torch.ops.aten.arange.start,
        torch.ops.aten.baddbmm,
        torch.ops.aten.binary_cross_entropy,
        torch.ops.aten.binary_cross_entropy_with_logits,
        torch.ops.aten.block_diag,
        torch.ops.aten.celu,
        torch.ops.aten.celu_,
        torch.ops.aten.clamp_max,
        torch.ops.aten.clamp_min,
        torch.ops.aten.count_nonzero,
        torch.ops.aten.linalg_cross,
        torch.ops.aten.cudnn_batch_norm,
        torch.ops.aten.deg2rad,
        torch.ops.aten.deg2rad_,
        torch.ops.aten.detach,
        torch.ops.aten.diag_embed,
        torch.ops.aten.dot,
        torch.ops.aten.vdot,
        torch.ops.aten.elu,
        torch.ops.aten.elu_,
        torch.ops.aten._embedding_bag,
        torch.ops.aten.empty_like,
        torch.ops.aten._euclidean_dist.default,
        torch.ops.aten.expand_as,
        torch.ops.aten.eye,
        torch.ops.aten.fill,
        torch.ops.aten.fill_,
        torch.ops.aten.floor_divide,
        torch.ops.aten.frac,
        torch.ops.aten.frac_,
        torch.ops.aten._fused_moving_avg_obs_fq_helper,
        torch.ops.aten.gelu_,
        torch.ops.aten.glu,
        torch.ops.aten.hardshrink,
        torch.ops.aten.hardsigmoid,
        torch.ops.aten.hardsigmoid_,
        torch.ops.aten.hardswish,
        torch.ops.aten.hardswish_,
        torch.ops.aten.hardtanh_,
        torch.ops.aten.heaviside,
        torch.ops.aten.heaviside_,
        torch.ops.aten.huber_loss,
        torch.ops.aten.im2col,
        torch.ops.aten.index_add,
        torch.ops.aten.index_add_,
        torch.ops.aten.index_copy,
        torch.ops.aten.index_copy_,
        torch.ops.aten.index_fill,
        torch.ops.aten.index_fill_,
        torch.ops.aten.isin,
        torch.ops.aten.isneginf,
        torch.ops.aten.isposinf,
        torch.ops.aten.l1_loss,
        torch.ops.aten.leaky_relu_,
        torch.ops.aten.lerp,
        torch.ops.aten.lerp_,
        torch.ops.aten.linspace,
        torch.ops.aten.logaddexp,
        torch.ops.aten.logaddexp2,
        torch.ops.aten.logit,
        torch.ops.aten.logit_,
        torch.ops.aten.log_sigmoid_forward,
        torch.ops.aten.logspace,
        torch.ops.aten.logsumexp.default,
        torch.ops.aten.masked_fill,
        torch.ops.aten.masked_fill_,
        torch.ops.aten.mish,
        torch.ops.aten.mish_,
        torch.ops.aten.mse_loss,
        torch.ops.aten.multi_margin_loss,
        torch.ops.aten.multilabel_margin_loss_forward,
        torch.ops.aten.mv,
        torch.ops.aten.mvlgamma,
        torch.ops.aten.mvlgamma_,
        torch.ops.aten.nansum,
        torch.ops.aten.nan_to_num,
        torch.ops.aten.nan_to_num_,
        torch.ops.aten.narrow,
        torch.ops.aten.new_empty,
        torch.ops.aten.new_full,
        torch.ops.aten.new_ones,
        torch.ops.aten.new_zeros,
        torch.ops.aten.nll_loss_forward,
        torch.ops.aten.norm,
        torch.ops.aten.ones,
        torch.ops.aten.ones_like,
        torch.ops.aten._prelu_kernel,
        torch.ops.aten._reshape_alias,
        torch.ops.aten.rad2deg,
        torch.ops.aten.rad2deg_,
        torch.ops.aten.reflection_pad1d,
        torch.ops.aten.reflection_pad2d,
        torch.ops.aten.reflection_pad3d,
        torch.ops.aten.replication_pad1d,
        torch.ops.aten.replication_pad2d,
        torch.ops.aten.replication_pad3d,
        torch.ops.aten.renorm,
        torch.ops.aten.renorm_,
        torch.ops.aten.resize_as,
        torch.ops.aten.roll,
        torch.ops.aten.rot90,
        torch.ops.aten.rrelu_with_noise,
        torch.ops.aten.rrelu_with_noise_,
        torch.ops.aten.rsub,
        torch.ops.aten.select_scatter,
        torch.ops.aten.sgn,
        torch.ops.aten.sgn_,
        torch.ops.aten.silu,
        torch.ops.aten.silu_,
        torch.ops.aten.sinc,
        torch.ops.aten.sinc_,
        torch.ops.aten.smooth_l1_loss,
        torch.ops.aten.soft_margin_loss,
        torch.ops.aten.softplus,
        torch.ops.aten.softshrink,
        torch.ops.aten.special_entr,
        torch.ops.aten.special_log_ndtr,
        torch.ops.aten.special_xlog1py,
        torch.ops.aten.split.Tensor,
        torch.ops.aten.split_with_sizes_copy,
        torch.ops.aten.squeeze.default,
        torch.ops.aten.squeeze.dim,
        torch.ops.aten.std,
        torch.ops.aten.std_mean,
        torch.ops.aten.stack,
        torch.ops.aten.sum.default,
        torch.ops.aten.sum.out,
        torch.ops.aten.t,
        torch.ops.aten.take,
        torch.ops.aten.threshold,
        torch.ops.aten.threshold_,
        torch.ops.aten.trace,
        torch.ops.aten.transpose.int,
        torch.ops.aten.tril,
        torch.ops.aten.tril_,
        torch.ops.aten.triu,
        torch.ops.aten.triu_,
        torch.ops.aten.unbind,
        torch.ops.aten.unfold_copy,
        torch.ops.aten._unsafe_index,
        torch.ops.aten.unsafe_split.Tensor,
        torch.ops.aten.unsafe_split_with_sizes,
        torch.ops.aten._unsafe_view,
        torch.ops.aten.view_as_complex,
        torch.ops.aten.xlogy,
        torch.ops.aten.xlogy_,
        torch.ops.aten.zero,
        torch.ops.aten.zero_,
        torch.ops.aten.zeros,
        torch.ops.aten.zeros_like,
        torch.ops.aten._weight_norm_interface,
    ]
    try:
        from packaging import version
        if version.parse(torch.__version__) >= version.parse("2.3"):
            decomp += [
                torch.ops.aten._lazy_clone,
                torch.ops.aten._test_parallel_materialize,
                torch.ops.aten._chunk_cat,
            ]
    except ImportError:
        pass
    return decomp
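On the maintenance concern raised in the thread, one possible aid is a small helper that reports curated entries which no longer appear in torch's core ATen decomposition table, so the list can be pruned when the torch version moves. This is purely a sketch, not part of the PR:

```python
import torch
from torch._decomp import core_aten_decompositions
from openvino.frontend.pytorch.torchdynamo.decompositions import get_export_decomposition_list


def report_stale_entries():
    # Ops that core ATen decomposes in the current torch build.
    core = set(core_aten_decompositions().keys())
    for op in get_export_decomposition_list():
        # Entries may be an OpOverloadPacket (e.g. aten.addcdiv) or a single
        # OpOverload (e.g. aten.arange.default); expand packets to overloads.
        if isinstance(op, torch._ops.OpOverloadPacket):
            overloads = [getattr(op, name) for name in op.overloads()]
        else:
            overloads = [op]
        if not any(o in core for o in overloads):
            print("possibly stale:", op)
```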
1 change: 1 addition & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.clamp_min.default", op::translate_1to1_match_2_inputs_align_types<opset10::Maximum>},
{"aten.clamp_min.Tensor", op::translate_1to1_match_2_inputs_align_types<opset10::Maximum>},
{"aten.clone.default", op::skip_node}, // ignore clone operators that are inserted by PyTorch autograd
{"aten.col2im.default", op::translate_col2im},
{"aten.constant_pad_nd.default", op::translate_constant_pad_nd_fx},
{"aten.convolution.default", op::translate_convolution},
{"aten.copy.default", op::translate_copy_fx},
Expand Down
19 changes: 8 additions & 11 deletions tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
 import warnings
 from copy import deepcopy
 import os
 
+import torch
+import pytest
+import logging
+import numpy as np
 
 from common.constants import test_device, test_precision
 from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
 
 from openvino.frontend import FrontEndManager
 from openvino.runtime import Core, Type, PartialShape
 import openvino.properties.hint as hints
-import torch
-from packaging import version
-import pytest
 
+logging.basicConfig(level=logging.DEBUG)
 
def skip_check(param):

@@ -124,13 +125,9 @@ def numpy_to_torch_recursively(x):
                 from torch.export import export
 
                 em = export(model, tuple(torch_inputs))
-                if version.parse(torch.__version__) >= version.parse("2.3"):
-                    em = em.run_decompositions()
-                gm = em.module()
-                print(gm.code)
 
                 converted_model = convert_model(
-                    em, example_input=torch_inputs)
+                    em, example_input=torch_inputs, verbose=True)
                 self._resolve_input_shape_dtype(
                     converted_model, ov_inputs, dynamic_shapes)
                 smodel = model

@@ -242,7 +239,7 @@ def convert_via_mo(self, model, example_input, trace_model, dynamic_shapes, ov_inputs):
         if not dynamic_shapes:
             input_shapes = [inp.shape for inp in ov_inputs]
             kwargs["input"] = input_shapes
-        om = convert_model(decoder, **kwargs)
+        om = convert_model(decoder, verbose=True, **kwargs)
         self._resolve_input_shape_dtype(om, ov_inputs, dynamic_shapes)
         return smodel, om
1 change: 1 addition & 0 deletions tests/layer_tests/pytorch_tests/test_col2im.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def forward(self, x):

     @pytest.mark.nightly
     @pytest.mark.precommit
+    @pytest.mark.precommit_torch_export
     @pytest.mark.parametrize("output_size,kernel_size", [([4, 5], [2, 2])])
     @pytest.mark.parametrize("dilation", [1, 2, [1, 2]])
     @pytest.mark.parametrize("padding", [0, 5, [2, 3]])
20 changes: 11 additions & 9 deletions tests/layer_tests/pytorch_tests/test_eye.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

 import pytest
 import torch
+from packaging import version
 
 from pytorch_layer_test_class import PytorchLayerTest

@@ -14,7 +15,6 @@ def _prepare_input(self, m, n=None):
return (np.array(m, dtype="int32"), )
return (np.array(m, dtype="int32"), np.array(n, dtype="int32"))


def create_model(self, num_inputs, dtype):
import torch
dtype_map = {

@@ -45,29 +45,31 @@ def __init__(self, dtype):
             def forward(self, x, y):
                 return torch.eye(x, y, dtype=self.dtype)
 
-
-        ref_net = None
-
-        return aten_eye_1_input(pt_dtype) if num_inputs == 1 else aten_eye_2_inputs(pt_dtype), ref_net, ("aten::eye", "aten::IntImplicit")
+        model = aten_eye_1_input(pt_dtype) if num_inputs == 1 else aten_eye_2_inputs(pt_dtype)
+        return model, None, ["aten::eye", "aten::IntImplicit"]
 
     @pytest.mark.nightly
     @pytest.mark.precommit
+    @pytest.mark.precommit_torch_export
     @pytest.mark.parametrize("dtype", ["bool", "int8", "uint8", "int32", "int64", "float32", "float64"])
     @pytest.mark.parametrize("m", [2, 3, 4, 5])
-    @pytest.mark.skipif(torch.__version__ < '2.3.0', reason="`aten.eye` is not supported in PyTorch versions earlier than 2.3.")
     def test_eye_square(self, dtype, m, ie_device, precision, ir_version):
+        if PytorchLayerTest.use_torch_export() and version.parse(torch.__version__) < version.parse("2.3"):
+            pytest.skip("Not supported in PyTorch versions earlier than 2.3.")
         if ie_device == "GPU":
             pytest.xfail(reason="eye is not supported on GPU")
-        self._test(*self.create_model(1, dtype), ie_device, precision, ir_version, kwargs_to_prepare_input={"m": m})
+        self._test(*self.create_model(1, dtype), ie_device, precision,
+                   ir_version, kwargs_to_prepare_input={"m": m})
 
     @pytest.mark.nightly
     @pytest.mark.precommit
+    @pytest.mark.precommit_torch_export
     @pytest.mark.parametrize("dtype", ["bool", "int8", "uint8", "int32", "int64", "float32", "float64"])
     @pytest.mark.parametrize(("m", "n"), [[2, 2], [3, 4], [5, 3]])
-    @pytest.mark.skipif(torch.__version__ < '2.3.0', reason="`aten.eye` is not supported in PyTorch versions earlier than 2.3.")
     def test_eye(self, dtype, m, n, ie_device, precision, ir_version):
+        if PytorchLayerTest.use_torch_export() and version.parse(torch.__version__) < version.parse("2.3"):
+            pytest.skip("Not supported in PyTorch versions earlier than 2.3.")
         if ie_device == "GPU":
             pytest.xfail(reason="eye is not supported on GPU")
-        self._test(*self.create_model(2, dtype), ie_device, precision, ir_version, kwargs_to_prepare_input={"m": m, "n": n})
+        self._test(*self.create_model(2, dtype), ie_device, precision,
+                   ir_version, kwargs_to_prepare_input={"m": m, "n": n})
5 changes: 4 additions & 1 deletion tests/model_hub_tests/pytorch/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ def convert_model_impl(self, model_obj):
         pt_res = model_obj(**self.example)
         graph = export(model_obj, tuple(), self.example)
         if version.parse(torch.__version__) >= version.parse("2.2"):
-            graph = graph.run_decompositions()
+            from torch._decomp import get_decompositions
+            from openvino.frontend.pytorch.torchdynamo.decompositions import get_export_decomposition_list
+            decomp = get_decompositions(get_export_decomposition_list())
+            graph = graph.run_decompositions(decomp_table=decomp)
 
         gm = graph.module()
         print(gm.code)

@@ -40,15 +40,20 @@ def extract_module_extensions(args):
         except:
             pass
         if not is_good_version:
-            raise RuntimeError(
-                "NNCF models produced by nncf<2.6 are not supported directly. Please upgrade nncf or export to ONNX first.")
+            raise RuntimeError("NNCF models produced by nncf<2.6 are not "
+                               "supported directly. Please upgrade nncf or "
+                               "export to ONNX first.")
     inputs = prepare_torch_inputs(example_inputs)
     if not isinstance(model, (TorchScriptPythonDecoder, TorchFXPythonDecoder)):
         if hasattr(torch, "export") and isinstance(model, (torch.export.ExportedProgram)):
             from packaging import version
             if version.parse(torch.__version__) >= version.parse("2.2"):
-                model = model.run_decompositions()
+                from torch._decomp import get_decompositions
+                from openvino.frontend.pytorch.torchdynamo.decompositions import get_export_decomposition_list
+                decomp = get_decompositions(get_export_decomposition_list())
+                model = model.run_decompositions(decomp_table=decomp)
             gm = model.module()
             log.debug(gm.code)
             decoder = TorchFXPythonDecoder(gm)
         else:
             decoder = TorchScriptPythonDecoder(
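Taken together with the decomposition list above, the change in this file means an ExportedProgram can be passed straight to convert_model and the curated decompositions are applied internally. A minimal end-to-end sketch, assuming OpenVINO with this change and torch >= 2.2 (Net is a made-up toy module):

```python
import torch
import openvino as ov


class Net(torch.nn.Module):  # toy module, for illustration only
    def forward(self, x):
        return torch.nn.functional.silu(x)


ep = torch.export.export(Net(), (torch.rand(2, 3),))
# convert_model detects the ExportedProgram and runs
# get_export_decomposition_list() via run_decompositions itself.
ov_model = ov.convert_model(ep)
```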