From 1001c18dc8ccb2ac56700fdb8ac1df3c73879f7d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 9 Jan 2025 16:42:40 +0000 Subject: [PATCH] [BE] Check ordering and exclusivity of tensorclass registers ghstack-source-id: becd6b07c03eccaab2733e604b3dfb21ec05ebb6 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1176 --- test/test_tensorclass.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index 91b7dbd99..c64c1aea4 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -121,7 +121,6 @@ def test_tensorclass_stub_methods(): ] if missing_methods: - raise Exception(f"Missing methods in tensorclass.pyi: {missing_methods}") raise Exception( f"Missing methods in tensorclass.pyi: {sorted(missing_methods)}" ) @@ -150,6 +149,32 @@ class X: ) +def test_sorted_methods(): + from tensordict.tensorclass import ( + _FALLBACK_METHOD_FROM_TD, + _FALLBACK_METHOD_FROM_TD_FORCE, + _FALLBACK_METHOD_FROM_TD_NOWRAP, + _METHOD_FROM_TD, + ) + + lists_to_check = [ + _FALLBACK_METHOD_FROM_TD_NOWRAP, + _METHOD_FROM_TD, + _FALLBACK_METHOD_FROM_TD_FORCE, + _FALLBACK_METHOD_FROM_TD, + ] + # Check that each list is sorted and has unique elements + for lst in lists_to_check: + assert lst == sorted(lst), f"List {lst} is not sorted" + assert len(lst) == len(set(lst)), f"List {lst} has duplicate elements" + # Check that no two lists share any elements + for i, lst1 in enumerate(lists_to_check): + for j, lst2 in enumerate(lists_to_check): + if i != j: + shared_elements = set(lst1) & set(lst2) + assert ( + not shared_elements + ), f"Lists {lst1} and {lst2} share elements: {shared_elements}" def _make_data(shape):