[BE] tensorclass method registration check
ghstack-source-id: feed6dd1a9f3cd97d3ef31d0b848916f21d98bed
Pull Request resolved: #1175
vmoens committed Jan 9, 2025
1 parent 7f100f3 commit 02ab260
Showing 3 changed files with 147 additions and 12 deletions.
8 changes: 4 additions & 4 deletions tensordict/base.py
@@ -7938,7 +7938,7 @@ def reduce(
async_op=False,
return_premature=False,
group=None,
):
) -> None:
"""Reduces the tensordict across all machines.

Only the process with rank ``dst`` is going to receive the final result.
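(``reduce`` mirrors ``torch.distributed.reduce``. A minimal usage sketch, assuming a two-rank process group has already been initialized; the values shown are illustrative, not from the source.)

import torch
import torch.distributed as dist
from tensordict import TensorDict

# Run on every rank; assumes dist.init_process_group() was called with world_size=2.
assert dist.is_initialized()
td = TensorDict({"a": torch.ones(3)}, batch_size=[3])
td.reduce(dst=0)  # in-place sum (the default op) across ranks
# Only rank 0 receives the result: td["a"] == tensor([2., 2., 2.])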
@@ -9028,7 +9028,7 @@ def newfn(item_and_out):
return out

# Stream
def record_stream(self, stream: torch.cuda.Stream):
def record_stream(self, stream: torch.cuda.Stream) -> T:
"""Marks the tensordict as having been used by this stream.

When the tensordict is deallocated, ensure the tensor memory is not reused for other tensors until all work queued on the stream at the time of deallocation is complete.
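(The pattern mirrors ``torch.Tensor.record_stream``; a minimal sketch, assuming a CUDA device and a tensordict consumed on a side stream. Names here are illustrative.)

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.ones(1000, device="cuda")}, batch_size=[])
side = torch.cuda.Stream()
with torch.cuda.stream(side):
    out = td + 1  # work queued on the side stream consumes td's buffers
# Keep td's memory from being reused until work queued on `side` completes.
td.record_stream(side)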
@@ -11345,7 +11345,7 @@ def copy(self):
"""
return self.clone(recurse=False)

def to_padded_tensor(self, padding=0.0, mask_key: NestedKey | None = None):
def to_padded_tensor(self, padding=0.0, mask_key: NestedKey | None = None) -> T:
"""Converts all nested tensors to a padded version and adapts the batch-size accordingly.

Args:
@@ -12430,7 +12430,7 @@ def split_keys(
default: Any = NO_DEFAULT,
strict: bool = True,
reproduce_struct: bool = False,
):
) -> Tuple[T, ...]:
"""Splits the tensordict in subsets given one or more set of keys.

The method will return ``N+1`` tensordicts, where ``N`` is the number of key sets passed.
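(A hedged usage sketch: three keys, two key sets; the trailing output collecting the leftover keys is my reading of the docstring.)

import torch
from tensordict import TensorDict

td = TensorDict(
    {"a": torch.zeros(3), "b": torch.ones(3), "c": torch.full((3,), 2.0)},
    batch_size=[3],
)
td_a, td_b, rest = td.split_keys(["a"], ["b"])  # N=2 key sets -> N+1 outputs
assert "c" in rest.keys()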
115 changes: 108 additions & 7 deletions tensordict/tensorclass.py
@@ -119,6 +119,7 @@ def __subclasscheck__(self, subclass):
}
# Methods to be executed from tensordict, any ref to self means 'tensorclass'
_METHOD_FROM_TD = [
"dumps",
"load_",
"memmap",
"memmap_",
@@ -143,21 +144,48 @@ def __subclasscheck__(self, subclass):
"_items_list",
"_maybe_names",
"_multithread_apply_flat",
"_multithread_apply_nest",
"_multithread_rebuild", # rebuild checks if self is a non tensor
"_propagate_lock",
"_propagate_unlock",
"_reduce_get_metadata",
"_values_list",
"bytes",
"cat_tensors",
"data_ptr",
"depth",
"dim",
"dtype",
"entry_class",
"get_item_shape",
"get_non_tensor",
"irecv",
"is_consolidated",
"is_contiguous",
"is_cpu",
"is_cuda",
"is_empty",
"is_floating_point",
"is_memmap",
"is_meta",
"is_shared",
"isend",
"items",
"keys",
"make_memmap",
"make_memmap_from_tensor",
"ndimension",
"numel",
"numpy",
"param_count",
"pop",
"recv",
"reduce",
"saved_path",
"send",
"size",
"sorted_keys",
"to_struct_array",
"values",
# "ndim",
]
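(For orientation: entries in ``_METHOD_FROM_TD`` are resolved on the wrapped tensordict rather than reimplemented on the tensorclass. A schematic sketch of that kind of delegation, with illustrative names only; ``_delegate`` is not the library's actual helper.)

def _delegate(method_name):
    # Look the method up on the wrapped tensordict and return its result
    # unchanged; no re-wrapping into the tensorclass happens here.
    def wrapper(self, *args, **kwargs):
        return getattr(self._tensordict, method_name)(*args, **kwargs)
    wrapper.__name__ = method_name
    return wrapper

# Schematically: for name in _METHOD_FROM_TD: setattr(cls, name, _delegate(name))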
@@ -212,9 +240,6 @@ def __subclasscheck__(self, subclass):
"_map",
"_maybe_remove_batch_dim",
"_memmap_",
"_multithread_apply_flat",
"_multithread_apply_nest",
"_multithread_rebuild",
"_permute",
"_remove_batch_dim",
"_repeat",
@@ -233,6 +258,8 @@
"addcmul",
"addcmul_",
"all",
"amax",
"amin",
"any",
"apply",
"apply_",
@@ -243,31 +270,43 @@
"atan_",
"auto_batch_size_",
"auto_device_",
"bfloat16",
"bitwise_and",
"bool",
"cat",
"cat_from_tensordict",
"ceil",
"ceil_",
"chunk",
"clamp",
"clamp_max",
"clamp_max_",
"clamp_min",
"clamp_min_",
"clear",
"clear_device_",
"complex128",
"complex32",
"complex64",
"consolidate",
"contiguous",
"copy_",
"copy_at_",
"cos",
"cos_",
"cosh",
"cosh_",
"cpu",
"create_nested",
"cuda",
"cummax",
"cummin",
"densify",
"detach",
"detach_",
"div",
"div_",
"double",
"empty",
"erf",
"erf_",
@@ -280,20 +319,43 @@
"expand_as",
"expm1",
"expm1_",
"fill_",
"filter_empty_",
"filter_non_tensor_data",
"flatten",
"flatten_keys",
"float",
"float16",
"float32",
"float64",
"floor",
"floor_",
"frac",
"frac_",
"from_any",
"from_consolidated",
"from_dataclass",
"from_h5",
"from_modules",
"from_namedtuple",
"from_pytree",
"from_struct_array",
"from_tuple",
"fromkeys",
"gather",
"gather_and_stack",
"half",
"int",
"int16",
"int32",
"int64",
"int8",
"isfinite",
"isnan",
"isneginf",
"isposinf",
"isreal",
"lazy_stack",
"lerp",
"lerp_",
"lgamma",
@@ -310,13 +372,16 @@
"log_",
"logical_and",
"logsumexp",
"make_memmap_from_storage",
"map",
"map_iter",
"masked_fill",
"masked_fill_",
"masked_select",
"max",
"maximum",
"maximum_",
"maybe_dense_stack",
"mean",
"min",
"minimum",
@@ -336,13 +401,22 @@
"norm",
"permute",
"pin_memory",
"pin_memory_",
"popitem",
"pow",
"pow_",
"prod",
"qint32",
"qint8",
"quint4x2",
"quint8",
"reciprocal",
"reciprocal_",
"record_stream",
"refine_names",
"rename",
"rename_", # TODO: must be specialized
"rename_key_",
"repeat",
"repeat_interleave",
"replace",
@@ -351,6 +425,10 @@
"round",
"round_",
"select",
"separates",
"set_",
"set_non_tensor",
"setdefault",
"sigmoid",
"sigmoid_",
"sign",
@@ -361,9 +439,13 @@
"sinh_",
"softmax",
"split",
"split_keys",
"sqrt",
"sqrt_",
"squeeze",
"stack",
"stack_from_tensordict",
"stack_tensors",
"std",
"sub",
"sub_",
@@ -373,13 +455,21 @@
"tanh",
"tanh_",
"to",
"to_h5",
"to_module",
"to_namedtuple",
"to_padded_tensor",
"to_pytree",
"transpose",
"trunc",
"trunc_",
"type",
"uint16",
"uint32",
"uint64",
"uint8",
"unflatten",
"unflatten_keys",
"unlock_",
"unsqueeze",
"var",
@@ -388,10 +478,6 @@
"zero_",
"zero_grad",
]
assert not any(v in _METHOD_FROM_TD for v in _FALLBACK_METHOD_FROM_TD), set(
_METHOD_FROM_TD
).intersection(_FALLBACK_METHOD_FROM_TD)
assert len(set(_FALLBACK_METHOD_FROM_TD)) == len(_FALLBACK_METHOD_FROM_TD)
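(The fallback list, by contrast, covers methods whose tensordict result should come back as the tensorclass. A rough sketch of that re-wrapping, assuming the generated ``_from_tensordict`` constructor; the wiring shown is schematic, not the library's code.)

from tensordict import is_tensor_collection

def _fallback(method_name):
    def wrapper(self, *args, **kwargs):
        out = getattr(self._tensordict, method_name)(*args, **kwargs)
        # Re-wrap tensordict outputs so chained calls keep the tensorclass type.
        if is_tensor_collection(out):
            return type(self)._from_tensordict(out)
        return out
    wrapper.__name__ = method_name
    return wrapper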

# These methods require a copy of the non tensor data
_FALLBACK_METHOD_FROM_TD_COPY = [
@@ -863,6 +949,14 @@ def __torch_function__(
cls.device = property(_device, _device_setter)
if not hasattr(cls, "batch_size") and "batch_size" not in expected_keys:
cls.batch_size = property(_batch_size, _batch_size_setter)
if not hasattr(cls, "batch_dims") and "batch_dims" not in expected_keys:
cls.batch_dims = property(_batch_dims)
if not hasattr(cls, "requires_grad") and "requires_grad" not in expected_keys:
cls.requires_grad = property(_requires_grad)
if not hasattr(cls, "is_locked") and "is_locked" not in expected_keys:
cls.is_locked = property(_is_locked)
if not hasattr(cls, "ndim") and "ndim" not in expected_keys:
cls.ndim = property(_batch_dims)
if not hasattr(cls, "shape") and "shape" not in expected_keys:
cls.shape = property(_batch_size, _batch_size_setter)
if not hasattr(cls, "names") and "names" not in expected_keys:
@@ -2157,6 +2251,13 @@ def _batch_size(self) -> torch.Size:
"""
return self._tensordict.batch_size

def _batch_dims(self) -> int:
return self._tensordict.batch_dims

def _requires_grad(self) -> bool:
return self._tensordict.requires_grad

def _is_locked(self) -> bool:
return self._tensordict.is_locked


def _batch_size_setter(self, new_size: torch.Size) -> None: # noqa: D417
"""Set the value of batch_size.
36 changes: 35 additions & 1 deletion test/test_tensorclass.py
@@ -99,7 +99,7 @@ def _get_methods_from_class(cls):
methods = set()
for name in dir(cls):
attr = getattr(cls, name)
if inspect.isfunction(attr) or inspect.ismethod(attr):
if inspect.isfunction(attr) or inspect.ismethod(attr) or isinstance(attr, property):
methods.add(name)

return methods
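(The added ``isinstance(attr, property)`` arm matters because a property accessed on the class is a descriptor object, which ``inspect`` classifies as neither a function nor a method.)

import inspect

class A:
    @property
    def p(self):
        return 0

assert not inspect.isfunction(A.p)  # A.p is a property object
assert not inspect.ismethod(A.p)
assert isinstance(A.p, property)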
@@ -122,6 +122,34 @@ def test_tensorclass_stub_methods():

if missing_methods:
raise Exception(f"Missing methods in tensorclass.pyi: {missing_methods}")
raise Exception(
f"Missing methods in tensorclass.pyi: {sorted(missing_methods)}"
)


def test_tensorclass_instance_methods():
@tensorclass
class X:
x: torch.Tensor

tensorclass_pyi_path = (
pathlib.Path(__file__).parent.parent / "tensordict/tensorclass.pyi"
)
tensorclass_abstract_methods = _get_methods_from_pyi(str(tensorclass_pyi_path))

tensorclass_methods = _get_methods_from_class(X)

missing_methods = tensorclass_abstract_methods - tensorclass_methods - {"data", "grad"}
missing_methods = [
method for method in missing_methods if (not method.startswith("_"))
]

if missing_methods:
raise Exception(
f"Missing methods in tensorclass.pyi: {sorted(missing_methods)}"
)


def _make_data(shape):
Expand Down Expand Up @@ -188,6 +216,12 @@ class MyClass1:
MyClass1(torch.zeros(3, 1), "z", batch_size=[3, 1]),
batch_size=[3, 1],
)
assert x.shape == x.batch_size
assert x.batch_size == (3, 1)
assert x.ndim == 2
assert x.batch_dims == 2
assert x.numel() == 3

assert not x.all()
assert not x.any()
assert isinstance(x.all(), bool)
