Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 9, 2025
2 parents 95b8732 + 73b686d commit cbd7a68
Show file tree
Hide file tree
Showing 6 changed files with 515 additions and 69 deletions.
266 changes: 212 additions & 54 deletions docs/source/overview.rst

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7946,7 +7946,7 @@ def reduce(
async_op=False,
return_premature=False,
group=None,
):
) -> None:
"""Reduces the tensordict across all machines.

Only the process with ``rank`` dst is going to receive the final result.
Expand Down Expand Up @@ -9036,7 +9036,7 @@ def newfn(item_and_out):
return out

# Stream
def record_stream(self, stream: torch.cuda.Stream):
def record_stream(self, stream: torch.cuda.Stream) -> T:
"""Marks the tensordict as having been used by this stream.

When the tensordict is deallocated, ensure the tensor memory is not reused for other tensors until all work
Expand Down Expand Up @@ -11353,7 +11353,7 @@ def copy(self):
"""
return self.clone(recurse=False)

def to_padded_tensor(self, padding=0.0, mask_key: NestedKey | None = None):
def to_padded_tensor(self, padding=0.0, mask_key: NestedKey | None = None) -> T:
"""Converts all nested tensors to a padded version and adapts the batch-size accordingly.

Args:
Expand Down Expand Up @@ -12438,7 +12438,7 @@ def split_keys(
default: Any = NO_DEFAULT,
strict: bool = True,
reproduce_struct: bool = False,
):
) -> Tuple[T, ...]:
"""Splits the tensordict in subsets given one or more set of keys.

The method will return ``N+1`` tensordicts, where ``N`` is the number of
Expand Down
9 changes: 6 additions & 3 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,9 +510,12 @@ def forward(
kwargs = {"aggregate_probabilities": False}
log_prob = dist.log_prob(out_tensors, **kwargs)
if log_prob is not out_tensors:
# Composite dists return the tensordict_out directly when aggrgate_prob is False
out_tensors.set(self.log_prob_key, log_prob)
else:
if is_tensor_collection(log_prob):
out_tensors.update(log_prob)
else:
# Composite dists return the tensordict_out directly when aggrgate_prob is False
out_tensors.set(self.log_prob_key, log_prob)
elif dist.log_prob_key in out_tensors:
out_tensors.rename_key_(dist.log_prob_key, self.log_prob_key)
tensordict_out.update(out_tensors)
else:
Expand Down
120 changes: 113 additions & 7 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def __subclasscheck__(self, subclass):
}
# Methods to be executed from tensordict, any ref to self means 'tensorclass'
_METHOD_FROM_TD = [
"dumps",
"load_",
"memmap",
"memmap_",
Expand All @@ -145,21 +146,48 @@ def __subclasscheck__(self, subclass):
"_items_list",
"_maybe_names",
"_multithread_apply_flat",
"_multithread_apply_nest",
"_multithread_rebuild", # rebuild checks if self is a non tensor
"_propagate_lock",
"_propagate_unlock",
"_reduce_get_metadata",
"_values_list",
"bytes",
"cat_tensors",
"data_ptr",
"depth",
"dim",
"dtype",
"entry_class",
"get_item_shape",
"get_non_tensor",
"irecv",
"is_consolidated",
"is_contiguous",
"is_cpu",
"is_cuda",
"is_empty",
"is_floating_point",
"is_memmap",
"is_meta",
"is_shared",
"isend",
"items",
"keys",
"make_memmap",
"make_memmap_from_tensor",
"ndimension",
"numel",
"numpy",
"param_count",
"pop",
"recv",
"reduce",
"saved_path",
"send",
"size",
"sorted_keys",
"to_struct_array",
"values",
# "ndim",
]
Expand Down Expand Up @@ -214,9 +242,6 @@ def __subclasscheck__(self, subclass):
"_map",
"_maybe_remove_batch_dim",
"_memmap_",
"_multithread_apply_flat",
"_multithread_apply_nest",
"_multithread_rebuild",
"_permute",
"_remove_batch_dim",
"_repeat",
Expand All @@ -235,6 +260,8 @@ def __subclasscheck__(self, subclass):
"addcmul",
"addcmul_",
"all",
"amax",
"amin",
"any",
"apply",
"apply_",
Expand All @@ -245,31 +272,43 @@ def __subclasscheck__(self, subclass):
"atan_",
"auto_batch_size_",
"auto_device_",
"bfloat16",
"bitwise_and",
"bool",
"cat",
"cat_from_tensordict",
"ceil",
"ceil_",
"chunk",
"clamp",
"clamp_max",
"clamp_max_",
"clamp_min",
"clamp_min_",
"clear",
"clear_device_",
"complex128",
"complex32",
"complex64",
"consolidate",
"contiguous",
"copy_",
"copy_at_",
"cos",
"cos_",
"cosh",
"cosh_",
"cpu",
"create_nested",
"cuda",
"cummax",
"cummin",
"densify",
"detach",
"detach_",
"div",
"div_",
"double",
"empty",
"erf",
"erf_",
Expand All @@ -282,20 +321,43 @@ def __subclasscheck__(self, subclass):
"expand_as",
"expm1",
"expm1_",
"fill_",
"filter_empty_",
"filter_non_tensor_data",
"flatten",
"flatten_keys",
"float",
"float16",
"float32",
"float64",
"floor",
"floor_",
"frac",
"frac_",
"from_any",
"from_consolidated",
"from_dataclass",
"from_h5",
"from_modules",
"from_namedtuple",
"from_pytree",
"from_struct_array",
"from_tuple",
"fromkeys",
"gather",
"gather_and_stack",
"half",
"int",
"int16",
"int32",
"int64",
"int8",
"isfinite",
"isnan",
"isneginf",
"isposinf",
"isreal",
"lazy_stack",
"lerp",
"lerp_",
"lgamma",
Expand All @@ -312,13 +374,16 @@ def __subclasscheck__(self, subclass):
"log_",
"logical_and",
"logsumexp",
"make_memmap_from_storage",
"map",
"map_iter",
"masked_fill",
"masked_fill_",
"masked_select",
"max",
"maximum",
"maximum_",
"maybe_dense_stack",
"mean",
"min",
"minimum",
Expand All @@ -338,13 +403,22 @@ def __subclasscheck__(self, subclass):
"norm",
"permute",
"pin_memory",
"pin_memory_",
"popitem",
"pow",
"pow_",
"prod",
"qint32",
"qint8",
"quint4x2",
"quint8",
"reciprocal",
"reciprocal_",
"record_stream",
"refine_names",
"rename",
"rename_", # TODO: must be specialized
"rename_key_",
"repeat",
"repeat_interleave",
"replace",
Expand All @@ -353,6 +427,10 @@ def __subclasscheck__(self, subclass):
"round",
"round_",
"select",
"separates",
"set_",
"set_non_tensor",
"setdefault",
"sigmoid",
"sigmoid_",
"sign",
Expand All @@ -363,9 +441,13 @@ def __subclasscheck__(self, subclass):
"sinh_",
"softmax",
"split",
"split_keys",
"sqrt",
"sqrt_",
"squeeze",
"stack",
"stack_from_tensordict",
"stack_tensors",
"std",
"sub",
"sub_",
Expand All @@ -375,13 +457,21 @@ def __subclasscheck__(self, subclass):
"tanh",
"tanh_",
"to",
"to_h5",
"to_module",
"to_namedtuple",
"to_padded_tensor",
"to_pytree",
"transpose",
"trunc",
"trunc_",
"type",
"uint16",
"uint32",
"uint64",
"uint8",
"unflatten",
"unflatten_keys",
"unlock_",
"unsqueeze",
"var",
Expand All @@ -390,10 +480,6 @@ def __subclasscheck__(self, subclass):
"zero_",
"zero_grad",
]
assert not any(v in _METHOD_FROM_TD for v in _FALLBACK_METHOD_FROM_TD), set(
_METHOD_FROM_TD
).intersection(_FALLBACK_METHOD_FROM_TD)
assert len(set(_FALLBACK_METHOD_FROM_TD)) == len(_FALLBACK_METHOD_FROM_TD)

# These methods require a copy of the non tensor data
_FALLBACK_METHOD_FROM_TD_COPY = [
Expand Down Expand Up @@ -865,6 +951,14 @@ def __torch_function__(
cls.device = property(_device, _device_setter)
if not hasattr(cls, "batch_size") and "batch_size" not in expected_keys:
cls.batch_size = property(_batch_size, _batch_size_setter)
if not hasattr(cls, "batch_dims") and "batch_dims" not in expected_keys:
cls.batch_dims = property(_batch_dims)
if not hasattr(cls, "requires_grad") and "requires_grad" not in expected_keys:
cls.requires_grad = property(_requires_grad)
if not hasattr(cls, "is_locked") and "is_locked" not in expected_keys:
cls.is_locked = property(_is_locked)
if not hasattr(cls, "ndim") and "ndim" not in expected_keys:
cls.ndim = property(_batch_dims)
if not hasattr(cls, "shape") and "shape" not in expected_keys:
cls.shape = property(_batch_size, _batch_size_setter)
if not hasattr(cls, "names") and "names" not in expected_keys:
Expand Down Expand Up @@ -2160,6 +2254,18 @@ def _batch_size(self) -> torch.Size:
return self._tensordict.batch_size


def _batch_dims(self) -> torch.Size:
return self._tensordict.batch_dims


def _requires_grad(self) -> torch.Size:
return self._tensordict.requires_grad


def _is_locked(self) -> torch.Size:
return self._tensordict.is_locked


def _batch_size_setter(self, new_size: torch.Size) -> None: # noqa: D417
"""Set the value of batch_size.
Expand Down
Loading

0 comments on commit cbd7a68

Please sign in to comment.