diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py index 6d8883393d5..276c706baef 100644 --- a/sota-implementations/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -5,7 +5,7 @@ """ This script reproduces the Proximal Policy Optimization (PPO) Algorithm -results from Schulman et al. 2017 for the on Atari Environments. +results from Schulman et al. 2017 for the Atari Environments. """ import hydra from torchrl._utils import logger as torchrl_logger diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 22aec1cbb0d..1c473d31297 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -516,7 +516,7 @@ def append_transform( self, transform: "Transform" # noqa: F821 | Callable[[TensorDictBase], TensorDictBase], - ) -> None: + ) -> EnvBase: """Returns a transformed environment where the callable/transform passed is applied. Args: @@ -1482,7 +1482,8 @@ def full_state_spec(self, spec: Composite) -> None: # Single-env specs can be used to remove the batch size from the spec @property - def batch_dims(self): + def batch_dims(self) -> int: + """Number of batch dimensions of the env.""" return len(self.batch_size) def _make_single_env_spec(self, spec: TensorSpec) -> TensorSpec: @@ -2444,11 +2445,11 @@ def rollout( set_truncated: bool = False, out=None, trust_policy: bool = False, - ): + ) -> TensorDictBase: """Executes a rollout in the environment. - The function will stop as soon as one of the contained environments - returns done=True. + The function will return as soon as any of the contained environments + reaches any of the done states. Args: max_steps (int): maximum number of steps to be executed. The actual number of steps can be smaller if @@ -2464,14 +2465,16 @@ def rollout( the call to ``rollout``. Keyword Args: - auto_reset (bool, optional): if ``True``, resets automatically the environment - if it is in a done state when the rollout is initiated. - Default is ``True``. 
+ auto_reset (bool, optional): if ``True``, the contained environments will be reset before starting the + rollout. If ``False``, then the rollout will continue from a previous state, which requires the + ``tensordict`` argument to be passed with the previous rollout. Default is ``True``. auto_cast_to_device (bool, optional): if ``True``, the device of the tensordict is automatically cast to the policy device before the policy is used. Default is ``False``. - break_when_any_done (bool): breaks if any of the done state is True. If False, a reset() is - called on the sub-envs that are done. Default is True. - break_when_all_done (bool): TODO + break_when_any_done (bool, optional): if ``True``, break when any of the contained environments reaches any of the + done states. If ``False``, then the done environments are reset automatically. Default is ``True``. + break_when_all_done (bool, optional): if ``True``, break if all of the contained environments reach any + of the done states. If ``False``, early stopping is governed by ``break_when_any_done``. + Default is ``False``. return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True. tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial tensordict must be provided. Rollout will check if this tensordict has done flags and reset the