From 1f927fadd8e450b531de0e3f76e8ec4df9f202b0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 9 Jan 2025 16:42:32 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- tensordict/tensorclass.pyi | 53 +++++++++++++++++++++++++++++++- test/test_tensorclass.py | 63 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+), 1 deletion(-) diff --git a/tensordict/tensorclass.pyi b/tensordict/tensorclass.pyi index 87a7d7753..f4131802d 100644 --- a/tensordict/tensorclass.pyi +++ b/tensordict/tensorclass.pyi @@ -526,7 +526,15 @@ class TensorClass: return_early: bool = False, share_non_tensor: bool = False, ) -> T: ... - dumps = save + def dumps( + self, + prefix: str | None = None, + copy_existing: bool = False, + *, + num_threads: int = 0, + return_early: bool = False, + share_non_tensor: bool = False, + ) -> T: ... def memmap( self, prefix: str | None = None, @@ -892,6 +900,14 @@ class TensorClass: *, default: str | CompatibleType | None = None, ) -> T: ... + def clamp( + self, + min: TensorDictBase | torch.Tensor = None, + max: TensorDictBase | torch.Tensor = None, + *, + out=None, + ): ... + def logsumexp(self, dim=None, keepdim=False, *, out=None): ... def clamp_max_(self, other: TensorDictBase | torch.Tensor) -> T: ... def clamp_max( self, @@ -944,6 +960,27 @@ class TensorClass: def to_namedtuple(self, dest_cls: type | None = None): ... @classmethod def from_namedtuple(cls, named_tuple, *, auto_batch_size: bool = False): ... + def from_tuple( + cls, + obj, + *, + auto_batch_size: bool = False, + batch_dims: int | None = None, + device: torch.device | None = None, + batch_size: torch.Size | None = None, + ): ... + def logical_and( + self, + other: TensorDictBase | torch.Tensor, + *, + default: str | CompatibleType | None = None, + ) -> TensorDictBase: ... + def bitwise_and( + self, + other: TensorDictBase | torch.Tensor, + *, + default: str | CompatibleType | None = None, + ) -> TensorDictBase: ... @classmethod def from_struct_array( cls, struct_array: np.ndarray, device: torch.device | None = None @@ -987,6 +1024,20 @@ class TensorClass: strict: bool = True, reproduce_struct: bool = False, ): ... + def separates( + self, + *keys: NestedKey, + default: Any = NO_DEFAULT, + strict: bool = True, + filter_empty: bool = True, + ) -> T: ... + def norm( + self, + *, + out=None, + dtype: torch.dtype | None = None, + ): ... + def softmax(self, dim: int, dtype: torch.dtype | None = None): ... @property def is_locked(self) -> bool: ... @is_locked.setter diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index e2ccc9989..d9d4be692 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -5,10 +5,12 @@ from __future__ import annotations import argparse +import ast import contextlib import dataclasses import inspect import os +import pathlib import pickle import re import sys @@ -61,6 +63,67 @@ ] +def _get_methods_from_pyi(file_path): + """ + Reads a .pyi file and returns a set of method names. + + Args: + file_path (str): Path to the .pyi file. + + Returns: + set: A set of method names. + """ + with open(file_path, "r") as f: + tree = ast.parse(f.read()) + + methods = set() + for node in tree.body: + if isinstance(node, ast.ClassDef): + for child_node in node.body: + if isinstance(child_node, ast.FunctionDef): + methods.add(child_node.name) + + return methods + + +def _get_methods_from_class(cls): + """ + Returns a set of method names from a given class. + + Args: + cls (class): The class to get methods from. + + Returns: + set: A set of method names. + """ + methods = set() + for name in dir(cls): + attr = getattr(cls, name) + if inspect.isfunction(attr) or inspect.ismethod(attr): + methods.add(name) + + return methods + + +def test_tensorclass_stub_methods(): + tensorclass_pyi_path = ( + pathlib.Path(__file__).parent.parent / "tensordict/tensorclass.pyi" + ) + tensorclass_methods = _get_methods_from_pyi(str(tensorclass_pyi_path)) + + from tensordict import TensorDict + + tensordict_methods = _get_methods_from_class(TensorDict) + + missing_methods = tensordict_methods - tensorclass_methods + missing_methods = [ + method for method in missing_methods if (not method.startswith("_")) + ] + + if missing_methods: + raise Exception(f"Missing methods in tensorclass.pyi: {missing_methods}") + + def _make_data(shape): return MyData( X=torch.rand(*shape),