diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index 91b7dbd99..c64c1aea4 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -121,7 +121,6 @@ def test_tensorclass_stub_methods(): ] if missing_methods: - raise Exception(f"Missing methods in tensorclass.pyi: {missing_methods}") raise Exception( f"Missing methods in tensorclass.pyi: {sorted(missing_methods)}" ) @@ -150,6 +149,32 @@ class X: ) +def test_sorted_methods(): + from tensordict.tensorclass import ( + _FALLBACK_METHOD_FROM_TD, + _FALLBACK_METHOD_FROM_TD_FORCE, + _FALLBACK_METHOD_FROM_TD_NOWRAP, + _METHOD_FROM_TD, + ) + + lists_to_check = [ + _FALLBACK_METHOD_FROM_TD_NOWRAP, + _METHOD_FROM_TD, + _FALLBACK_METHOD_FROM_TD_FORCE, + _FALLBACK_METHOD_FROM_TD, + ] + # Check that each list is sorted and has unique elements + for lst in lists_to_check: + assert lst == sorted(lst), f"List {lst} is not sorted" + assert len(lst) == len(set(lst)), f"List {lst} has duplicate elements" + # Check that no two lists share any elements + for i, lst1 in enumerate(lists_to_check): + for j, lst2 in enumerate(lists_to_check): + if i != j: + shared_elements = set(lst1) & set(lst2) + assert ( + not shared_elements + ), f"Lists {lst1} and {lst2} share elements: {shared_elements}" def _make_data(shape):