[BugFix] Fix pre 2.1 _apply compatibility
ghstack-source-id: b8e890e36e0b15dda039e74004c5ab63af16435b
Pull Request resolved: #1050
vmoens committed Oct 21, 2024
1 parent 75b33c4 commit b2ccfe7
Showing 1 changed file with 22 additions and 1 deletion.
tensordict/nn/params.py
@@ -35,14 +35,14 @@
    _LOCK_ERROR,
    BufferLegacy,
    erase_cache,
    implement_for,
    IndexType,
    is_batchedtensor,
    lock_blocked,
)
from torch import multiprocessing as mp, nn, Tensor
from torch.utils._pytree import tree_map


try:
    from functorch import dim as ftdim

@@ -1160,6 +1160,7 @@ def update_at_(
@_apply_on_data
def apply_(self, fn: Callable, *others, **kwargs) -> T: ...

@implement_for("torch", "2.1")
def _apply(self, fn, recurse=True):
    self._param_td._erase_cache()
    param_td = self._param_td
@@ -1179,6 +1180,26 @@ def _apply(self, fn, recurse=True):
        out.auto_device_()
    return out

@implement_for("torch", None, "2.1")
def _apply(self, fn): # noqa: F811
self._param_td._erase_cache()
param_td = self._param_td
self._param_td = param_td.copy()
# Keep a list of buffers to update .data only
bufs = dict(self._buffers)
out: TensorDictBase = super()._apply(fn)
for key, val in bufs.items():
val.data = self._buffers[key].data
self._buffers[key] = val
# Check device and shape
cbs = out._check_batch_size(raise_exception=False)
if not cbs:
out.auto_batch_size_()
cd = out._check_device(raise_exception=False)
if not cd:
out.auto_device_()
return out


TDPARAM_HANDLED_FUNCTIONS = copy(TD_HANDLED_FUNCTIONS)

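For context: the two _apply overloads differ only in whether they accept a recurse argument, which nn.Module._apply does not take on torch releases older than 2.1; the implement_for decorator keeps whichever definition matches the installed torch version. Below is a minimal sketch of that kind of version-gated dispatch, assuming a hypothetical _torch_version helper and plain class-body branching rather than tensordict's actual implement_for machinery:

import torch


def _torch_version() -> tuple:
    # "2.0.1+cu118" -> (2, 0): enough for a major/minor comparison.
    major, minor = torch.__version__.split("+")[0].split(".")[:2]
    return int(major), int(minor)


class _VersionedApplySketch(torch.nn.Module):
    # Hypothetical illustration: keep the _apply signature that matches the
    # parent class on the installed torch version.
    if _torch_version() >= (2, 1):

        def _apply(self, fn, recurse=True):
            # torch >= 2.1: nn.Module._apply accepts ``recurse``; forward it.
            return super()._apply(fn, recurse=recurse)

    else:

        def _apply(self, fn):  # noqa: F811
            # torch < 2.1: the parent signature is _apply(fn) only.
            return super()._apply(fn)


# Usage: whichever overload was kept is the one nn.Module.float()/.to() hits.
mod = _VersionedApplySketch()
mod.weight = torch.nn.Parameter(torch.zeros(2))
mod.float()

The decorated definitions in the diff express the same choice declaratively: implement_for("torch", "2.1") keeps the first overload on torch >= 2.1, and implement_for("torch", None, "2.1") keeps the second on anything older.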
