diff --git a/tensordict/utils.py b/tensordict/utils.py index 0e370856f..f282edd53 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -22,7 +22,6 @@ from collections.abc import KeysView from contextlib import nullcontext from copy import copy -from dataclasses import _FIELDS, GenericAlias from functools import wraps from importlib import import_module from numbers import Number @@ -75,6 +74,18 @@ if TYPE_CHECKING: from tensordict.tensordict import TensorDictBase +try: + from dataclasses import _FIELDS, GenericAlias +except ImportError: + # python < 3.9 + from dataclasses import _FIELDS + + class GenericAlias: + """Placeholder.""" + + ... + + try: try: from torch._C._functorch import ( # @manual=fbcode//caffe2:torch diff --git a/test/smoke_test.py b/test/smoke_test.py index deb3684a8..d3c6a8a06 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -2,8 +2,16 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import argparse + +import pytest def test_imports(): from tensordict import TensorDict # noqa: F401 from tensordict.nn import TensorDictModule # noqa: F401 + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)