Skip to content

Commit

Permalink
[Performance] Faster copy of TDParams (#1096)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Nov 21, 2024
1 parent f606f7b commit bbfe8c7
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions tensordict/nn/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def new_func(self, *args, **kwargs):
if out is self._param_td:
return self
if not isinstance(out, TensorDictParams):
out = TensorDictParams(out, no_convert=True)
out = TensorDictParams(out, no_convert="skip")
out.no_convert = self.no_convert
return out

Expand Down Expand Up @@ -328,11 +328,12 @@ def __init__(
parameters = parameters._param_td
self._param_td = parameters
self.no_convert = no_convert
if not no_convert:
func = _maybe_make_param
else:
func = _maybe_make_param_or_buffer
self._param_td = _apply_leaves(self._param_td, lambda x: func(x))
if no_convert != "skip":
if not no_convert:
func = _maybe_make_param
else:
func = _maybe_make_param_or_buffer
self._param_td = _apply_leaves(self._param_td, lambda x: func(x))
self._lock_content = lock
if lock:
self._param_td.lock_()
Expand All @@ -341,6 +342,12 @@ def __init__(
self._locked_tensordicts = []
self._get_post_hook = []

@classmethod
def _new_unsafe(
cls, parameters: TensorDictBase, *, no_convert=False, lock: bool = False
):
return TensorDictParams(parameters, no_convert="skip", lock=lock)

def __iter__(self):
yield from self._param_td.__iter__()

Expand Down Expand Up @@ -613,7 +620,7 @@ def _clone(self, recurse: bool = True) -> TensorDictBase:
"""
if not recurse:
return TensorDictParams(self._param_td._clone(False), no_convert=True)
return TensorDictParams(self._param_td._clone(False), no_convert="skip")

memo = {}

Expand All @@ -631,7 +638,7 @@ def _clone(tensor, memo=memo):
memo[tensor] = result
return result

return TensorDictParams(self._param_td.apply(_clone), no_convert=True)
return TensorDictParams(self._param_td.apply(_clone), no_convert="skip")

@_fallback
def chunk(self, chunks: int, dim: int = 0) -> tuple[TensorDictBase, ...]: ...
Expand Down

0 comments on commit bbfe8c7

Please sign in to comment.