Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BE] TensorClass stub method check #1174

Merged
merged 2 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion tensordict/tensorclass.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,15 @@ class TensorClass:
return_early: bool = False,
share_non_tensor: bool = False,
) -> T: ...
dumps = save
def dumps(
self,
prefix: str | None = None,
copy_existing: bool = False,
*,
num_threads: int = 0,
return_early: bool = False,
share_non_tensor: bool = False,
) -> T: ...
def memmap(
self,
prefix: str | None = None,
Expand Down Expand Up @@ -892,6 +900,14 @@ class TensorClass:
*,
default: str | CompatibleType | None = None,
) -> T: ...
def clamp(
self,
min: TensorDictBase | torch.Tensor = None,
max: TensorDictBase | torch.Tensor = None,
*,
out=None,
): ...
def logsumexp(self, dim=None, keepdim=False, *, out=None): ...
def clamp_max_(self, other: TensorDictBase | torch.Tensor) -> T: ...
def clamp_max(
self,
Expand Down Expand Up @@ -944,6 +960,27 @@ class TensorClass:
def to_namedtuple(self, dest_cls: type | None = None): ...
@classmethod
def from_namedtuple(cls, named_tuple, *, auto_batch_size: bool = False): ...
def from_tuple(
cls,
obj,
*,
auto_batch_size: bool = False,
batch_dims: int | None = None,
device: torch.device | None = None,
batch_size: torch.Size | None = None,
): ...
def logical_and(
self,
other: TensorDictBase | torch.Tensor,
*,
default: str | CompatibleType | None = None,
) -> TensorDictBase: ...
def bitwise_and(
self,
other: TensorDictBase | torch.Tensor,
*,
default: str | CompatibleType | None = None,
) -> TensorDictBase: ...
@classmethod
def from_struct_array(
cls, struct_array: np.ndarray, device: torch.device | None = None
Expand Down Expand Up @@ -987,6 +1024,20 @@ class TensorClass:
strict: bool = True,
reproduce_struct: bool = False,
): ...
def separates(
self,
*keys: NestedKey,
default: Any = NO_DEFAULT,
strict: bool = True,
filter_empty: bool = True,
) -> T: ...
def norm(
self,
*,
out=None,
dtype: torch.dtype | None = None,
): ...
def softmax(self, dim: int, dtype: torch.dtype | None = None): ...
@property
def is_locked(self) -> bool: ...
@is_locked.setter
Expand Down
63 changes: 63 additions & 0 deletions test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from __future__ import annotations

import argparse
import ast
import contextlib
import dataclasses
import inspect
import os
import pathlib
import pickle
import re
import sys
Expand Down Expand Up @@ -61,6 +63,67 @@
]


def _get_methods_from_pyi(file_path):
"""
Reads a .pyi file and returns a set of method names.

Args:
file_path (str): Path to the .pyi file.

Returns:
set: A set of method names.
"""
with open(file_path, "r") as f:
tree = ast.parse(f.read())

methods = set()
for node in tree.body:
if isinstance(node, ast.ClassDef):
for child_node in node.body:
if isinstance(child_node, ast.FunctionDef):
methods.add(child_node.name)

return methods


def _get_methods_from_class(cls):
"""
Returns a set of method names from a given class.

Args:
cls (class): The class to get methods from.

Returns:
set: A set of method names.
"""
methods = set()
for name in dir(cls):
attr = getattr(cls, name)
if inspect.isfunction(attr) or inspect.ismethod(attr):
methods.add(name)

return methods


def test_tensorclass_stub_methods():
tensorclass_pyi_path = (
pathlib.Path(__file__).parent.parent / "tensordict/tensorclass.pyi"
)
tensorclass_methods = _get_methods_from_pyi(str(tensorclass_pyi_path))

from tensordict import TensorDict

tensordict_methods = _get_methods_from_class(TensorDict)

missing_methods = tensordict_methods - tensorclass_methods
missing_methods = [
method for method in missing_methods if (not method.startswith("_"))
]

if missing_methods:
raise Exception(f"Missing methods in tensorclass.pyi: {missing_methods}")


def _make_data(shape):
return MyData(
X=torch.rand(*shape),
Expand Down
Loading