From 2e7ba0b194cdb9e1120c88ffc338d88bfcd7beab Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 22 Nov 2024 23:54:36 +0000 Subject: [PATCH] Revert "Switch to using Python nested int (#141166)" This reverts commit e2e8a7fa2e519433a4ec1071f80d2f6f843c6300. Reverted https://github.com/pytorch/pytorch/pull/141166 on behalf of https://github.com/clee2000 due to broke docs [GH job link](https://github.com/pytorch/pytorch/actions/runs/11980936976/job/33406870951) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/e2e8a7fa2e519433a4ec1071f80d2f6f843c6300) ([comment](https://github.com/pytorch/pytorch/pull/141166#issuecomment-2495112297)) --- test/test_nestedtensor.py | 95 ------------------ torch/_subclasses/fake_tensor.py | 3 +- torch/fx/experimental/constant_symnode.py | 72 -------------- torch/nested/_internal/nested_int.py | 113 ---------------------- torch/nested/_internal/nested_tensor.py | 3 +- 5 files changed, 2 insertions(+), 284 deletions(-) delete mode 100644 torch/fx/experimental/constant_symnode.py delete mode 100644 torch/nested/_internal/nested_int.py diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 66290ae548985..eeb835f132fc6 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -8574,101 +8574,6 @@ def f(*args, **kwargs): self.assertEqualNoncontigAware(grads_compile, grads_ref) -from torch.nested._internal.nested_int import NestedIntNode - - -class TestNestedInt(torch.testing._internal.common_utils.TestCase): - def test_comparisons(self): - a = torch.SymInt(NestedIntNode(1, 1)) - b = torch.SymInt(NestedIntNode(1, 1)) - c = torch.SymInt(NestedIntNode(2, 1)) - d = 3 - - self.assertTrue(a == a) - self.assertTrue(a == b) - self.assertFalse(a != a) - self.assertFalse(a != b) - self.assertFalse(a == c) - self.assertTrue(a != c) - - self.assertFalse(a == d) - self.assertTrue(a != d) - self.assertFalse(d == a) - self.assertTrue(d != a) - - # ge - self.assertTrue(a >= a) - self.assertTrue(a >= b) - self.assertTrue(b >= a) - with self.assertRaises(ValueError): - _ = a >= c - with self.assertRaises(ValueError): - _ = c >= a - with self.assertRaises(ValueError): - _ = c >= 3 - self.assertTrue(c >= 2) - self.assertTrue(c >= 1) - self.assertFalse(c <= 1) - - # lt - self.assertFalse(a < a) - self.assertFalse(a < b) - self.assertFalse(b < a) - with self.assertRaises(ValueError): - _ = a < c - with self.assertRaises(ValueError): - _ = c < a - with self.assertRaises(ValueError): - _ = 3 < a - with self.assertRaises(ValueError): - _ = 2 < a - self.assertTrue(a > 1) - - # le - self.assertTrue(a <= a) - self.assertTrue(b <= a) - self.assertTrue(a <= b) - with self.assertRaises(ValueError): - _ = a <= c - with self.assertRaises(ValueError): - _ = c <= a - with self.assertRaises(ValueError): - _ = 3 <= c - self.assertTrue(c >= 2) - self.assertTrue(c >= 1) - self.assertFalse(c <= 1) - - # gt - self.assertFalse(a > a) - self.assertFalse(b > a) - self.assertFalse(a > b) - with self.assertRaises(ValueError): - _ = a > c - with self.assertRaises(ValueError): - _ = c > a - with self.assertRaises(ValueError): - _ = a > 3 - with self.assertRaises(ValueError): - _ = a > 2 - self.assertTrue(a > 1) - - def test_with_factor(self): - a = torch.SymInt(NestedIntNode(1, 5)) - b = torch.SymInt(NestedIntNode(1, 10)) - # eq - self.assertFalse(a == b) - self.assertFalse(a >= b) - self.assertTrue(b >= a) - self.assertTrue(a <= b) - self.assertFalse(b <= a) - # ne - self.assertTrue(a != b) - # mul - self.assertTrue(a * 2 == b) - self.assertTrue(a * 3 >= b) - self.assertTrue(a * 2 == 2 * a) - - instantiate_parametrized_tests(TestNestedTensor) instantiate_device_type_tests(TestNestedTensorDeviceType, globals()) instantiate_device_type_tests(TestNestedTensorAutograd, globals()) diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 090a18ad9f55d..7132e3b6c65f1 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -2513,13 +2513,12 @@ def create_symbolic_nested_int( # See Note: [Creating symbolic nested int] # Returned nested int always has coeff=1; multiply the result by coeff if needed import torch.nested._internal.nested_tensor - from torch.nested._internal.nested_int import NestedIntNode if nt_tensor_id is None: nt_tensor_id = self.nt_tensor_id_counter assert self.enter_stack, "should only called while FakeTensorMode is active" self.nt_tensor_id_counter += 1 - hint = torch.SymInt(NestedIntNode(nt_tensor_id, 1)) + hint = torch._C._get_nested_int(nt_tensor_id, 1) src = torch._dynamo.source.EphemeralSource("intermediate_offsets_or_lengths") assert self.shape_env is not None diff --git a/torch/fx/experimental/constant_symnode.py b/torch/fx/experimental/constant_symnode.py deleted file mode 100644 index f64e5abfe7c78..0000000000000 --- a/torch/fx/experimental/constant_symnode.py +++ /dev/null @@ -1,72 +0,0 @@ -from typing import * # noqa: F403 - - -__all__ = ["ConstantIntNode"] - - -# Python version of c10/core/ConstantSymNodeImpl.cpp -# This needs to exist because the Python version of nested int is not compatible -# with the C++ version of constant symnode. -class ConstantIntNode: - def __init__(self, val: int): - self.val = val - - def is_constant(self) -> bool: - return True - - def maybe_as_int(self) -> int: - return self.val - - def is_int(self) -> bool: - return True - - def is_float(self) -> bool: - return False - - def is_bool(self) -> bool: - return False - - def is_nested_int(self) -> bool: - return False - - def clone(self) -> "ConstantIntNode": - return self - - def _str(self) -> str: - return str(self.val) - - def __str__(self) -> str: - return self._str() - - def __repr__(self) -> str: - return self._str() - - def _graph_repr(self) -> str: - return self._str() - - def mul(self, other: Any) -> Any: - return other.mul(self) - - def eq(self, other: Any) -> Any: - return other.eq(self) - - def ne(self, other: Any) -> Any: - return other.ne(self) - - def gt(self, other: Any) -> Any: - return other.lt(self) - - def lt(self, other: Any) -> Any: - return other.gt(self) - - def le(self, other: Any) -> Any: - return other.ge(self) - - def ge(self, other: Any) -> Any: - return other.le(self) - - def is_symbolic(self) -> bool: - return False - - def constant_int(self) -> int: - return self.val diff --git a/torch/nested/_internal/nested_int.py b/torch/nested/_internal/nested_int.py deleted file mode 100644 index 59ebc8b1a5de0..0000000000000 --- a/torch/nested/_internal/nested_int.py +++ /dev/null @@ -1,113 +0,0 @@ -from typing import * # noqa: F403 - -import torch -from torch.fx.experimental.constant_symnode import ConstantIntNode - - -__all__ = ["NestedIntNode"] - - -# Python version of aten/src/ATen/core/NestedIntSymNodeImpl.cpp -def _eq(lhs: Any, rhs: Any) -> bool: - return ( - isinstance(lhs, NestedIntNode) - and isinstance(rhs, NestedIntNode) - and lhs.t_id == rhs.t_id - and lhs.coeff == rhs.coeff - ) - - -def _ge(lhs: Any, rhs: Any) -> bool: - if isinstance(rhs, NestedIntNode) and isinstance(lhs, NestedIntNode): - if lhs.t_id == rhs.t_id: - return lhs.coeff >= rhs.coeff - raise ValueError("ge: relation is indeterminate") - elif isinstance(lhs, NestedIntNode): - if rhs.is_constant() and rhs.constant_int() <= 2: - return True - raise ValueError("ge: relation is indeterminate") - elif isinstance(rhs, NestedIntNode): - if lhs.is_constant() and lhs.constant_int() < 2: - return False - raise ValueError("ge: relation is indeterminate") - else: - raise ValueError("inputs unsupported") - - -class NestedIntNode: - def __init__(self, t_id: int, coeff: int): - self.t_id = t_id - self.coeff = coeff - - def nested_int_coeff(self) -> int: - return self.coeff - - def maybe_as_int(self) -> Optional[int]: - return None - - def is_int(self) -> bool: - return True - - def is_float(self) -> bool: - return False - - def is_bool(self) -> bool: - return False - - def is_nested_int(self) -> bool: - return True - - def clone(self) -> "NestedIntNode": - return self - - def _str(self) -> str: - if self.coeff == 1: - return f"j{self.t_id}" - return f"{self.coeff}*j{self.t_id}" - - def __str__(self) -> str: - return self._str() - - def __repr__(self) -> str: - return self._str() - - def _graph_repr(self) -> str: - return self._str() - - def mul(self, other: Any) -> "NestedIntNode": - if other.is_constant(): - other = other.constant_int() - else: - raise ValueError(f"unsupported: {type(other)}") - return NestedIntNode(self.t_id, self.coeff * other) - - def eq(self, other: Any) -> Any: - return torch._C._get_constant_bool_symnode(_eq(self, other)) - - def ne(self, other: Any) -> Any: - return torch._C._get_constant_bool_symnode(not _eq(self, other)) - - def gt(self, other: Any) -> Any: - return torch._C._get_constant_bool_symnode(not _ge(other, self)) - - def lt(self, other: Any) -> Any: - return torch._C._get_constant_bool_symnode(not _ge(self, other)) - - def le(self, other: Any) -> Any: - return torch._C._get_constant_bool_symnode(_ge(other, self)) - - def ge(self, other: Any) -> Any: - return torch._C._get_constant_bool_symnode(_ge(self, other)) - - def is_symbolic(self) -> bool: - return False - - def nested_int(self) -> int: - return self.t_id - - def is_constant(self) -> bool: - return False - - def wrap_int(self, num: int) -> ConstantIntNode: - assert type(num) is int - return ConstantIntNode(num) diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index 7f1b81c099d88..d33ee7f9f9ea3 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -5,7 +5,6 @@ import torch from torch._C import DispatchKey, DispatchKeySet from torch._prims_common import is_expandable_to -from torch.nested._internal.nested_int import NestedIntNode from torch.utils.weak import WeakTensorKeyDictionary @@ -26,7 +25,7 @@ def get_tensor_symint(tensor, *, coeff=1): tensor_symint = _tensor_symint_registry.get(tensor) if tensor_symint is None: - tensor_symint = torch.SymInt(NestedIntNode(_tensor_id_counter, coeff)) + tensor_symint = torch._C._get_nested_int(_tensor_id_counter, coeff) _tensor_id_counter += 1 _tensor_symint_registry[tensor] = tensor_symint return tensor_symint