diff --git a/tensordict/_reductions.py b/tensordict/_reductions.py index 7234a42bd..be8aa42f1 100644 --- a/tensordict/_reductions.py +++ b/tensordict/_reductions.py @@ -5,7 +5,7 @@ from __future__ import annotations import copyreg -from multiprocessing.reduction import ForkingPickler +from multiprocessing import reduction import torch from tensordict._lazy import LazyStackedTensorDict @@ -160,10 +160,10 @@ def _reduce_td(data: TensorDict): # return (_rebuild_tensordict_files, (flat_key_values, metadata_dict)) -ForkingPickler.register(TensorDict, _reduce_td) +reduction.register(TensorDict, _reduce_td) copyreg.pickle(TensorDict, _reduce_td) -ForkingPickler.register(LazyStackedTensorDict, _reduce_td) +reduction.register(LazyStackedTensorDict, _reduce_td) copyreg.pickle(LazyStackedTensorDict, _reduce_td)