Skip to content

Commit

Permalink
Revert "Switch to using Python nested int (pytorch#141166)"
Browse files Browse the repository at this point in the history
  • Loading branch information
pytorchmergebot committed Nov 22, 2024
1 parent ee7eaad commit 2e7ba0b
Show file tree
Hide file tree
Showing 5 changed files with 2 additions and 284 deletions.
95 changes: 0 additions & 95 deletions test/test_nestedtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8574,101 +8574,6 @@ def f(*args, **kwargs):
self.assertEqualNoncontigAware(grads_compile, grads_ref)


from torch.nested._internal.nested_int import NestedIntNode


class TestNestedInt(torch.testing._internal.common_utils.TestCase):
def test_comparisons(self):
a = torch.SymInt(NestedIntNode(1, 1))
b = torch.SymInt(NestedIntNode(1, 1))
c = torch.SymInt(NestedIntNode(2, 1))
d = 3

self.assertTrue(a == a)
self.assertTrue(a == b)
self.assertFalse(a != a)
self.assertFalse(a != b)
self.assertFalse(a == c)
self.assertTrue(a != c)

self.assertFalse(a == d)
self.assertTrue(a != d)
self.assertFalse(d == a)
self.assertTrue(d != a)

# ge
self.assertTrue(a >= a)
self.assertTrue(a >= b)
self.assertTrue(b >= a)
with self.assertRaises(ValueError):
_ = a >= c
with self.assertRaises(ValueError):
_ = c >= a
with self.assertRaises(ValueError):
_ = c >= 3
self.assertTrue(c >= 2)
self.assertTrue(c >= 1)
self.assertFalse(c <= 1)

# lt
self.assertFalse(a < a)
self.assertFalse(a < b)
self.assertFalse(b < a)
with self.assertRaises(ValueError):
_ = a < c
with self.assertRaises(ValueError):
_ = c < a
with self.assertRaises(ValueError):
_ = 3 < a
with self.assertRaises(ValueError):
_ = 2 < a
self.assertTrue(a > 1)

# le
self.assertTrue(a <= a)
self.assertTrue(b <= a)
self.assertTrue(a <= b)
with self.assertRaises(ValueError):
_ = a <= c
with self.assertRaises(ValueError):
_ = c <= a
with self.assertRaises(ValueError):
_ = 3 <= c
self.assertTrue(c >= 2)
self.assertTrue(c >= 1)
self.assertFalse(c <= 1)

# gt
self.assertFalse(a > a)
self.assertFalse(b > a)
self.assertFalse(a > b)
with self.assertRaises(ValueError):
_ = a > c
with self.assertRaises(ValueError):
_ = c > a
with self.assertRaises(ValueError):
_ = a > 3
with self.assertRaises(ValueError):
_ = a > 2
self.assertTrue(a > 1)

def test_with_factor(self):
a = torch.SymInt(NestedIntNode(1, 5))
b = torch.SymInt(NestedIntNode(1, 10))
# eq
self.assertFalse(a == b)
self.assertFalse(a >= b)
self.assertTrue(b >= a)
self.assertTrue(a <= b)
self.assertFalse(b <= a)
# ne
self.assertTrue(a != b)
# mul
self.assertTrue(a * 2 == b)
self.assertTrue(a * 3 >= b)
self.assertTrue(a * 2 == 2 * a)


instantiate_parametrized_tests(TestNestedTensor)
instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
instantiate_device_type_tests(TestNestedTensorAutograd, globals())
Expand Down
3 changes: 1 addition & 2 deletions torch/_subclasses/fake_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2513,13 +2513,12 @@ def create_symbolic_nested_int(
# See Note: [Creating symbolic nested int]
# Returned nested int always has coeff=1; multiply the result by coeff if needed
import torch.nested._internal.nested_tensor
from torch.nested._internal.nested_int import NestedIntNode

if nt_tensor_id is None:
nt_tensor_id = self.nt_tensor_id_counter
assert self.enter_stack, "should only called while FakeTensorMode is active"
self.nt_tensor_id_counter += 1
hint = torch.SymInt(NestedIntNode(nt_tensor_id, 1))
hint = torch._C._get_nested_int(nt_tensor_id, 1)

src = torch._dynamo.source.EphemeralSource("intermediate_offsets_or_lengths")
assert self.shape_env is not None
Expand Down
72 changes: 0 additions & 72 deletions torch/fx/experimental/constant_symnode.py

This file was deleted.

113 changes: 0 additions & 113 deletions torch/nested/_internal/nested_int.py

This file was deleted.

3 changes: 1 addition & 2 deletions torch/nested/_internal/nested_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch
from torch._C import DispatchKey, DispatchKeySet
from torch._prims_common import is_expandable_to
from torch.nested._internal.nested_int import NestedIntNode
from torch.utils.weak import WeakTensorKeyDictionary


Expand All @@ -26,7 +25,7 @@ def get_tensor_symint(tensor, *, coeff=1):

tensor_symint = _tensor_symint_registry.get(tensor)
if tensor_symint is None:
tensor_symint = torch.SymInt(NestedIntNode(_tensor_id_counter, coeff))
tensor_symint = torch._C._get_nested_int(_tensor_id_counter, coeff)
_tensor_id_counter += 1
_tensor_symint_registry[tensor] = tensor_symint
return tensor_symint
Expand Down

0 comments on commit 2e7ba0b

Please sign in to comment.