diff --git a/README.md b/README.md
index 59f4e880..f54ae854 100644
--- a/README.md
+++ b/README.md
@@ -25,11 +25,12 @@ We hope this allows us to provide reliable implementations following stable-base
 See documentation for the full list of included features.
 
 **RL Algorithms**:
-- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
+- [Augmented Random Search (ARS)](https://arxiv.org/abs/1803.07055)
 - [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
 - [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171)
+- [PPO with recurrent policy (RecurrentPPO aka PPO LSTM)](https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/)
+- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
 - [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477)
-- [Augmented Random Search (ARS)](https://arxiv.org/abs/1803.07055)
 
 **Gym Wrappers**:
 - [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)
diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst
index 4b2cfb43..234e6f8c 100644
--- a/docs/guide/algos.rst
+++ b/docs/guide/algos.rst
@@ -9,14 +9,17 @@ along with some useful characteristics: support for discrete/continuous actions,
 Name         ``Box``     ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
 ============ =========== ============ ================= =============== ================
 ARS          ✔️          ❌️           ❌                ❌               ✔️
+MaskablePPO  ❌          ✔️           ✔️                ✔️               ✔️
 QR-DQN       ️❌          ️✔️           ❌                ❌               ✔️
+RecurrentPPO ✔️          ✔️           ✔️                ✔️               ✔️
 TQC          ✔️          ❌           ❌                ❌               ✔️
 TRPO         ✔️          ✔️           ✔️                ✔️               ✔️
 ============ =========== ============ ================= =============== ================
 
 .. note::
-  Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm.
+  ``Tuple`` observation spaces are not supported by any algorithm;
+  however, single-level ``Dict`` observation spaces are supported.
 
 Actions ``gym.spaces``:
diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst
index d33d4a94..cd4851d0 100644
--- a/docs/guide/examples.rst
+++ b/docs/guide/examples.rst
@@ -71,3 +71,38 @@ Train an agent using Augmented Random Search (ARS) agent on the Pendulum environ
     model = ARS("LinearPolicy", "Pendulum-v1", verbose=1)
     model.learn(total_timesteps=10000, log_interval=4)
     model.save("ars_pendulum")
+
+RecurrentPPO
+------------
+
+Train a PPO agent with a recurrent policy on the CartPole environment.
+
+
+.. note::
+
+   It is particularly important to pass the ``lstm_states``
+   and ``episode_start`` arguments to the ``predict()`` method,
+   so the cell and hidden states of the LSTM are correctly updated.
+
+
+.. 
code-block:: python + + import numpy as np + + from sb3_contrib import RecurrentPPO + + model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1) + model.learn(5000) + + env = model.get_env() + obs = env.reset() + # cell and hidden state of the LSTM + lstm_states = None + num_envs = 1 + # Episode start signals are used to reset the lstm states + episode_starts = np.ones((num_envs,), dtype=bool) + while True: + action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True) + obs, rewards, dones, info = env.step(action) + episode_starts = dones + env.render() diff --git a/docs/index.rst b/docs/index.rst index f713dfaa..5e322652 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -33,6 +33,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d modules/ars modules/ppo_mask + modules/ppo_recurrent modules/qrdqn modules/tqc modules/trpo diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 3de9b469..5554ded5 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,9 +3,11 @@ Changelog ========== -Release 1.5.1a7 (WIP) +Release 1.5.1a8 (WIP) ------------------------------- +**Add RecurrentPPO (aka PPO LSTM)** + Breaking Changes: ^^^^^^^^^^^^^^^^^ - Upgraded to Stable-Baselines3 >= 1.5.1a7 @@ -17,6 +19,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ +- Added ``RecurrentPPO`` (aka PPO LSTM) Bug Fixes: ^^^^^^^^^^ @@ -34,7 +37,8 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ -- Allow PPO to turn of advantage normalization (see `PR #61 `_) @vwxyzjn +- Allow PPO to turn of advantage normalization (see `PR #61 `_) (@vwxyzjn) + Bug Fixes: ^^^^^^^^^^ @@ -46,6 +50,9 @@ Deprecations: Others: ^^^^^^^ +Documentation: +^^^^^^^^^^^^^^ + Release 1.4.0 (2022-01-19) ------------------------------- **Add Trust Region Policy Optimization (TRPO) and Augmented Random Search (ARS) algorithms** diff --git a/docs/modules/ppo_mask.rst b/docs/modules/ppo_mask.rst index 9580ff38..a43f5969 100644 --- a/docs/modules/ppo_mask.rst +++ b/docs/modules/ppo_mask.rst @@ -5,7 +5,7 @@ Maskable PPO ============ -Implementation of `invalid action masking `_ for the Proximal Policy Optimization(PPO) +Implementation of `invalid action masking `_ for the Proximal Policy Optimization (PPO) algorithm. Other than adding support for action masking, the behavior is the same as in SB3's core PPO algorithm. diff --git a/docs/modules/ppo_recurrent.rst b/docs/modules/ppo_recurrent.rst new file mode 100644 index 00000000..bc819c39 --- /dev/null +++ b/docs/modules/ppo_recurrent.rst @@ -0,0 +1,153 @@ +.. _ppo_lstm: + +.. automodule:: sb3_contrib.ppo_recurrent + +Recurrent PPO +============= + +Implementation of recurrent policies for the Proximal Policy Optimization (PPO) +algorithm. Other than adding support for recurrent policies (LSTM here), the behavior is the same as in SB3's core PPO algorithm. + + +.. rubric:: Available Policies + +.. autosummary:: + :nosignatures: + + MlpLstmPolicy + CnnLstmPolicy + MultiInputLstmPolicy + + +Notes +----- + +- Blog post: https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/ + + +Can I use? +---------- + +- Recurrent policies: ✔️ +- Multi processing: ✔️ +- Gym spaces: + + +============= ====== =========== +Space Action Observation +============= ====== =========== +Discrete ✔️ ✔️ +Box ✔️ ✔️ +MultiDiscrete ✔️ ✔️ +MultiBinary ✔️ ✔️ +Dict ❌ ✔️ +============= ====== =========== + + +Example +------- + +.. 
note:: + + It is particularly important to pass the ``lstm_states`` + and ``episode_start`` argument to the ``predict()`` method, + so the cell and hidden states of the LSTM are correctly updated. + + +.. code-block:: python + + import numpy as np + + from sb3_contrib import RecurrentPPO + from stable_baselines3.common.evaluation import evaluate_policy + + model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1) + model.learn(5000) + + env = model.get_env() + mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20, warn=False) + print(mean_reward) + + model.save("ppo_recurrent") + del model # remove to demonstrate saving and loading + + model = RecurrentPPO.load("ppo_recurrent") + + obs = env.reset() + # cell and hidden state of the LSTM + lstm_states = None + num_envs = 1 + # Episode start signals are used to reset the lstm states + episode_starts = np.ones((num_envs,), dtype=bool) + while True: + action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True) + obs, rewards, dones, info = env.step(action) + episode_starts = dones + env.render() + + + +Results +------- + +Report on environments with masked velocity (with and without framestack) can be found here: https://wandb.ai/sb3/no-vel-envs/reports/PPO-vs-RecurrentPPO-aka-PPO-LSTM-on-environments-with-masked-velocity--VmlldzoxOTI4NjE4 + +``RecurrentPPO`` was evaluated against PPO on: + +- PendulumNoVel-v1 +- LunarLanderNoVel-v2 +- CartPoleNoVel-v1 +- MountainCarContinuousNoVel-v0 +- CarRacing-v0 + +How to replicate the results? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Clone the repo for the experiment: + +.. code-block:: bash + + git clone https://github.com/DLR-RM/rl-baselines3-zoo + cd rl-baselines3-zoo + git checkout feat/recurrent-ppo + + +Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above): + +.. code-block:: bash + + python train.py --algo ppo_lstm --env $ENV_ID --eval-episodes 10 --eval-freq 10000 + + +Parameters +---------- + +.. autoclass:: RecurrentPPO + :members: + :inherited-members: + + +RecurrentPPO Policies +--------------------- + +.. autoclass:: MlpLstmPolicy + :members: + :inherited-members: + +.. autoclass:: sb3_contrib.common.recurrent.policies.RecurrentActorCriticPolicy + :members: + :noindex: + +.. autoclass:: CnnLstmPolicy + :members: + +.. autoclass:: sb3_contrib.common.recurrent.policies.RecurrentActorCriticCnnPolicy + :members: + :noindex: + +.. autoclass:: MultiInputLstmPolicy + :members: + +.. autoclass:: sb3_contrib.common.recurrent.policies.RecurrentMultiInputActorCriticPolicy + :members: + :noindex: diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py index 6ab21a10..1836ac48 100644 --- a/sb3_contrib/__init__.py +++ b/sb3_contrib/__init__.py @@ -2,6 +2,7 @@ from sb3_contrib.ars import ARS from sb3_contrib.ppo_mask import MaskablePPO +from sb3_contrib.ppo_recurrent import RecurrentPPO from sb3_contrib.qrdqn import QRDQN from sb3_contrib.tqc import TQC from sb3_contrib.trpo import TRPO diff --git a/sb3_contrib/common/maskable/policies.py b/sb3_contrib/common/maskable/policies.py index e521c343..a26f6fbb 100644 --- a/sb3_contrib/common/maskable/policies.py +++ b/sb3_contrib/common/maskable/policies.py @@ -215,12 +215,12 @@ def predict( action_masks: Optional[np.ndarray] = None, ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: """ - Get the policy action and state from an observation (and optional state). + Get the policy action from an observation (and optional hidden state). 
Includes sugar-coating to handle different observations (e.g. normalizing images). :param observation: the input observation :param state: The last states (can be None, used in recurrent policies) - :param mask: The last masks (can be None, used in recurrent policies) + :param episode_start: The last masks (can be None, used in recurrent policies) :param deterministic: Whether or not to return deterministic actions. :param action_masks: Action masks to apply to the action distribution :return: the model's action and the next state @@ -229,8 +229,8 @@ def predict( # TODO (GH/1): add support for RNN policies # if state is None: # state = self.initial_state - # if mask is None: - # mask = [False for _ in range(self.n_envs)] + # if episode_start is None: + # episode_start = [False for _ in range(self.n_envs)] # Switch to eval mode (this affects batch norm / dropout) self.set_training_mode(False) @@ -256,7 +256,7 @@ def predict( raise ValueError("Error: The environment must be vectorized when using recurrent policies.") actions = actions[0] - return actions, state + return actions, None def evaluate_actions( self, diff --git a/sb3_contrib/common/recurrent/__init__.py b/sb3_contrib/common/recurrent/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py new file mode 100644 index 00000000..88ff4254 --- /dev/null +++ b/sb3_contrib/common/recurrent/buffers.py @@ -0,0 +1,384 @@ +from functools import partial +from typing import Callable, Generator, Optional, Tuple, Union + +import numpy as np +import torch as th +from gym import spaces +from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer +from stable_baselines3.common.vec_env import VecNormalize + +from sb3_contrib.common.recurrent.type_aliases import ( + RecurrentDictRolloutBufferSamples, + RecurrentRolloutBufferSamples, + RNNStates, +) + + +def pad( + seq_start_indices: np.ndarray, + seq_end_indices: np.ndarray, + device: th.device, + tensor: np.ndarray, + padding_value: float = 0.0, +) -> th.Tensor: + """ + Chunk sequences and pad them to have constant dimensions. + + :param seq_start_indices: Indices of the transitions that start a sequence + :param seq_end_indices: Indices of the transitions that end a sequence + :param device: PyTorch device + :param tensor: Tensor of shape (batch_size, *tensor_shape) + :param padding_value: Value used to pad sequence to the same length + (zero padding by default) + :return: (n_seq, max_length, *tensor_shape) + """ + # Create sequences given start and end + seq = [th.tensor(tensor[start : end + 1], device=device) for start, end in zip(seq_start_indices, seq_end_indices)] + return th.nn.utils.rnn.pad_sequence(seq, batch_first=True, padding_value=padding_value) + + +def pad_and_flatten( + seq_start_indices: np.ndarray, + seq_end_indices: np.ndarray, + device: th.device, + tensor: np.ndarray, + padding_value: float = 0.0, +) -> th.Tensor: + """ + Pad and flatten the sequences of scalar values, + while keeping the sequence order. + From (batch_size, 1) to (n_seq, max_length, 1) -> (n_seq * max_length,) + + :param seq_start_indices: Indices of the transitions that start a sequence + :param seq_end_indices: Indices of the transitions that end a sequence + :param device: PyTorch device (cpu, gpu, ...) 
+ :param tensor: Tensor of shape (max_length, n_seq, 1) + :param padding_value: Value used to pad sequence to the same length + (zero padding by default) + :return: (n_seq * max_length,) aka (padded_batch_size,) + """ + return pad(seq_start_indices, seq_end_indices, device, tensor, padding_value).flatten() + + +def create_sequencers( + episode_starts: np.ndarray, + env_change: np.ndarray, + device: th.device, +) -> Tuple[np.ndarray, Callable, Callable]: + """ + Create the utility function to chunk data into + sequences and pad them to create fixed size tensors. + + :param episode_starts: Indices where an episode starts + :param env_change: Indices where the data collected + come from a different env (when using multiple env for data collection) + :param device: PyTorch device + :return: Indices of the transitions that start a sequence, + pad and pad_and_flatten utilities tailored for this batch + (sequence starts and ends indices are fixed) + """ + # Create sequence if env changes too + seq_start = np.logical_or(episode_starts, env_change).flatten() + # First index is always the beginning of a sequence + seq_start[0] = True + # Retrieve indices of sequence starts + seq_start_indices = np.where(seq_start == True)[0] # noqa: E712 + # End of sequence are just before sequence starts + # Last index is also always end of a sequence + seq_end_indices = np.concatenate([(seq_start_indices - 1)[1:], np.array([len(episode_starts)])]) + + # Create padding method for this minibatch + # to avoid repeating arguments (seq_start_indices, seq_end_indices) + local_pad = partial(pad, seq_start_indices, seq_end_indices, device) + local_pad_and_flatten = partial(pad_and_flatten, seq_start_indices, seq_end_indices, device) + return seq_start_indices, local_pad, local_pad_and_flatten + + +class RecurrentRolloutBuffer(RolloutBuffer): + """ + Rollout buffer that also stores the LSTM cell and hidden states. + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param hidden_state_shape: Shape of the buffer that will collect lstm states + (n_steps, lstm.num_layers, n_envs, lstm.hidden_size) + :param device: PyTorch device + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. 
+ :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + hidden_state_shape: Tuple[int, int, int, int], + device: Union[th.device, str] = "cpu", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + self.hidden_state_shape = hidden_state_shape + self.seq_start_indices, self.seq_end_indices = None, None + super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) + + def reset(self): + super().reset() + self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) + + def add(self, *args, lstm_states: RNNStates, **kwargs) -> None: + """ + :param hidden_states: LSTM cell and hidden state + """ + self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy()) + self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy()) + self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy()) + self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy()) + + super().add(*args, **kwargs) + + def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]: + assert self.full, "Rollout buffer must be full before sampling from it" + + # Prepare the data + if not self.generator_ready: + # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) + # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size) + for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]: + self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2) + + # flatten but keep the sequence order + # 1. (n_steps, n_envs, *tensor_shape) -> (n_envs, n_steps, *tensor_shape) + # 2. 
(n_envs, n_steps, *tensor_shape) -> (n_envs * n_steps, *tensor_shape) + for tensor in [ + "observations", + "actions", + "values", + "log_probs", + "advantages", + "returns", + "hidden_states_pi", + "cell_states_pi", + "hidden_states_vf", + "cell_states_vf", + "episode_starts", + ]: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + # Sampling strategy that allows any mini batch size but requires + # more complexity and use of padding + # Trick to shuffle a bit: keep the sequence order + # but split the indices in two + split_index = np.random.randint(self.buffer_size * self.n_envs) + indices = np.arange(self.buffer_size * self.n_envs) + indices = np.concatenate((indices[split_index:], indices[:split_index])) + + env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs) + # Flag first timestep as change of environment + env_change[0, :] = 1.0 + env_change = self.swap_and_flatten(env_change) + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + batch_inds = indices[start_idx : start_idx + batch_size] + yield self._get_samples(batch_inds, env_change) + start_idx += batch_size + + def _get_samples( + self, + batch_inds: np.ndarray, + env_change: np.ndarray, + env: Optional[VecNormalize] = None, + ) -> RecurrentRolloutBufferSamples: + # Retrieve sequence starts and utility function + self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers( + self.episode_starts[batch_inds], env_change[batch_inds], self.device + ) + + n_layers = self.hidden_states_pi.shape[1] + # Number of sequences + n_seq = len(self.seq_start_indices) + max_length = self.pad(self.actions[batch_inds]).shape[1] + padded_batch_size = n_seq * max_length + # We retrieve the lstm hidden states that will allow + # to properly initialize the LSTM at the beginning of each sequence + lstm_states_pi = ( + # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) + self.hidden_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), + self.cell_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), + ) + lstm_states_vf = ( + # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) + self.hidden_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), + self.cell_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), + ) + lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1])) + lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1])) + + return RecurrentRolloutBufferSamples( + # (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim) + observations=self.pad(self.observations[batch_inds]).reshape((padded_batch_size,) + self.obs_shape), + actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]), + old_values=self.pad_and_flatten(self.values[batch_inds]), + old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]), + advantages=self.pad_and_flatten(self.advantages[batch_inds]), + returns=self.pad_and_flatten(self.returns[batch_inds]), + lstm_states=RNNStates(lstm_states_pi, lstm_states_vf), + episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]), + mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])), + ) + + +class 
RecurrentDictRolloutBuffer(DictRolloutBuffer): + """ + Dict Rollout buffer used in on-policy algorithms like A2C/PPO. + Extends the RecurrentRolloutBuffer to use dictionary observations + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param hidden_state_shape: Shape of the buffer that will collect lstm states + :param device: PyTorch device + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. + :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + hidden_state_shape: Tuple[int, int, int, int], + device: Union[th.device, str] = "cpu", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + self.hidden_state_shape = hidden_state_shape + self.seq_start_indices, self.seq_end_indices = None, None + super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs) + + def reset(self): + super().reset() + self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) + + def add(self, *args, lstm_states: RNNStates, **kwargs) -> None: + """ + :param hidden_states: LSTM cell and hidden state + """ + self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy()) + self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy()) + self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy()) + self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy()) + + super().add(*args, **kwargs) + + def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRolloutBufferSamples, None, None]: + assert self.full, "Rollout buffer must be full before sampling from it" + + # Prepare the data + if not self.generator_ready: + # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) + # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size) + for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]: + self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2) + + for key, obs in self.observations.items(): + self.observations[key] = self.swap_and_flatten(obs) + + for tensor in [ + "actions", + "values", + "log_probs", + "advantages", + "returns", + "hidden_states_pi", + "cell_states_pi", + "hidden_states_vf", + "cell_states_vf", + "episode_starts", + ]: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + # Trick to shuffle a bit: keep the sequence order + # but split the indices in two + split_index = np.random.randint(self.buffer_size * self.n_envs) + indices = np.arange(self.buffer_size * self.n_envs) + indices = np.concatenate((indices[split_index:], indices[:split_index])) + + env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs) + # Flag first timestep as change of environment + env_change[0, :] = 1.0 + env_change = 
self.swap_and_flatten(env_change) + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + batch_inds = indices[start_idx : start_idx + batch_size] + yield self._get_samples(batch_inds, env_change) + start_idx += batch_size + + def _get_samples( + self, + batch_inds: np.ndarray, + env_change: np.ndarray, + env: Optional[VecNormalize] = None, + ) -> RecurrentDictRolloutBufferSamples: + # Retrieve sequence starts and utility function + self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers( + self.episode_starts[batch_inds], env_change[batch_inds], self.device + ) + + n_layers = self.hidden_states_pi.shape[1] + n_seq = len(self.seq_start_indices) + max_length = self.pad(self.actions[batch_inds]).shape[1] + padded_batch_size = n_seq * max_length + # We retrieve the lstm hidden states that will allow + # to properly initialize the LSTM at the beginning of each sequence + lstm_states_pi = ( + # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) + self.hidden_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), + self.cell_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), + ) + lstm_states_vf = ( + # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) + self.hidden_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), + self.cell_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), + ) + lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1])) + lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1])) + + observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()} + observations = {key: obs.reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items()} + + return RecurrentDictRolloutBufferSamples( + observations=observations, + actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]), + old_values=self.pad_and_flatten(self.values[batch_inds]), + old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]), + advantages=self.pad_and_flatten(self.advantages[batch_inds]), + returns=self.pad_and_flatten(self.returns[batch_inds]), + lstm_states=RNNStates(lstm_states_pi, lstm_states_vf), + episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]), + mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])), + ) diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py new file mode 100644 index 00000000..16f1c200 --- /dev/null +++ b/sb3_contrib/common/recurrent/policies.py @@ -0,0 +1,601 @@ +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import gym +import numpy as np +import torch as th +from stable_baselines3.common.distributions import Distribution +from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.torch_layers import ( + BaseFeaturesExtractor, + CombinedExtractor, + FlattenExtractor, + MlpExtractor, + NatureCNN, +) +from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.utils import zip_strict +from torch import nn + +from sb3_contrib.common.recurrent.type_aliases import RNNStates + + +class RecurrentActorCriticPolicy(ActorCriticPolicy): + """ + Recurrent policy class for actor-critic algorithms (has both policy and value prediction). + To be used with A2C, PPO and the likes. 
It assumes that both the actor and the critic LSTM
+    have the same architecture.
+
+    :param observation_space: Observation space
+    :param action_space: Action space
+    :param lr_schedule: Learning rate schedule (could be constant)
+    :param net_arch: The specification of the policy and value networks.
+    :param activation_fn: Activation function
+    :param ortho_init: Whether to use or not orthogonal initialization
+    :param use_sde: Whether to use State Dependent Exploration or not
+    :param log_std_init: Initial value for the log standard deviation
+    :param full_std: Whether to use (n_features x n_actions) parameters
+        for the std instead of only (n_features,) when using gSDE
+    :param sde_net_arch: Network architecture for extracting features
+        when using gSDE. If None, the latent features from the policy will be used.
+        Pass an empty list to use the states as features.
+    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
+        a positive standard deviation (cf paper). It allows to keep variance
+        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
+    :param squash_output: Whether to squash the output using a tanh function,
+        this allows to ensure boundaries when using gSDE.
+    :param features_extractor_class: Features extractor to use.
+    :param features_extractor_kwargs: Keyword arguments
+        to pass to the features extractor.
+    :param normalize_images: Whether to normalize images or not,
+        dividing by 255.0 (True by default)
+    :param optimizer_class: The optimizer to use,
+        ``th.optim.Adam`` by default
+    :param optimizer_kwargs: Additional keyword arguments,
+        excluding the learning rate, to pass to the optimizer
+    :param lstm_hidden_size: Number of hidden units for each LSTM layer.
+    :param n_lstm_layers: Number of LSTM layers.
+    :param shared_lstm: Whether the LSTM is shared between the actor and the critic
+        (in that case, only the actor gradient is used).
+        By default, the actor and the critic have two separate LSTMs.
+    :param enable_critic_lstm: Use a separate LSTM for the critic.
+    :param lstm_kwargs: Additional keyword arguments to pass to the LSTM
+        constructor.
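+
+    .. note::
+
+        ``shared_lstm=True`` and ``enable_critic_lstm=True`` cannot be combined
+        (an assertion in the constructor enforces this). The critic either reuses
+        the actor LSTM with detached gradients (``shared_lstm=True``), uses its own
+        LSTM (``enable_critic_lstm=True``, the default), or, when both are False,
+        falls back to a simple linear projection of the extracted features.
+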
+ """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + lstm_hidden_size: int = 256, + n_lstm_layers: int = 1, + shared_lstm: bool = False, + enable_critic_lstm: bool = True, + lstm_kwargs: Optional[Dict[str, Any]] = None, + ): + self.lstm_output_dim = lstm_hidden_size + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + sde_net_arch, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) + + self.lstm_kwargs = lstm_kwargs or {} + self.shared_lstm = shared_lstm + self.enable_critic_lstm = enable_critic_lstm + self.lstm_actor = nn.LSTM( + self.features_dim, + lstm_hidden_size, + num_layers=n_lstm_layers, + **self.lstm_kwargs, + ) + # For the predict() method, to initialize hidden states + # (n_lstm_layers, batch_size, lstm_hidden_size) + self.lstm_hidden_state_shape = (n_lstm_layers, 1, lstm_hidden_size) + self.critic = None + self.lstm_critic = None + assert not ( + self.shared_lstm and self.enable_critic_lstm + ), "You must choose between shared LSTM, seperate or no LSTM for the critic" + + # No LSTM for the critic, we still need to convert + # output of features extractor to the correct size + # (size of the output of the actor lstm) + if not (self.shared_lstm or self.enable_critic_lstm): + self.critic = nn.Linear(self.features_dim, lstm_hidden_size) + + # Use a separate LSTM for the critic + if self.enable_critic_lstm: + self.lstm_critic = nn.LSTM( + self.features_dim, + lstm_hidden_size, + num_layers=n_lstm_layers, + **self.lstm_kwargs, + ) + + # Setup optimizer with initial learning rate + self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + + def _build_mlp_extractor(self) -> None: + """ + Create the policy and value networks. + Part of the layers can be shared. + """ + self.mlp_extractor = MlpExtractor( + self.lstm_output_dim, + net_arch=self.net_arch, + activation_fn=self.activation_fn, + device=self.device, + ) + + @staticmethod + def _process_sequence( + features: th.Tensor, + lstm_states: Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + lstm: nn.LSTM, + ) -> Tuple[th.Tensor, th.Tensor]: + """ + Do a forward pass in the LSTM network. + + :param features: Input tensor + :param lstm_states: previous cell and hidden states of the LSTM + :param episode_starts: Indicates when a new episode starts, + in that case, we need to reset LSTM states. + :param lstm: LSTM object. + :return: LSTM output and updated LSTM states. 
+ """ + # LSTM logic + # (sequence length, batch size, features dim) + # (batch size = n_envs for data collection or n_seq when doing gradient update) + n_seq = lstm_states[0].shape[1] + # Batch to sequence + # (padded batch size, features_dim) -> (n_seq, max length, features_dim) -> (max length, n_seq, features_dim) + # note: max length (max sequence length) is always 1 during data collection + features_sequence = features.reshape((n_seq, -1, lstm.input_size)).swapaxes(0, 1) + episode_starts = episode_starts.reshape((n_seq, -1)).swapaxes(0, 1) + + # If we don't have to reset the state in the middle of a sequence + # we can avoid the for loop, which speeds up things + if th.all(episode_starts == 0.0): + lstm_output, lstm_states = lstm(features_sequence, lstm_states) + lstm_output = th.flatten(lstm_output.transpose(0, 1), start_dim=0, end_dim=1) + return lstm_output, lstm_states + + lstm_output = [] + # Iterate over the sequence + for features, episode_start in zip_strict(features_sequence, episode_starts): + hidden, lstm_states = lstm( + features.unsqueeze(dim=0), + ( + # Reset the states at the beginning of a new episode + (1.0 - episode_start).view(1, n_seq, 1) * lstm_states[0], + (1.0 - episode_start).view(1, n_seq, 1) * lstm_states[1], + ), + ) + lstm_output += [hidden] + # Sequence to batch + # (sequence length, n_seq, lstm_out_dim) -> (batch_size, lstm_out_dim) + lstm_output = th.flatten(th.cat(lstm_output).transpose(0, 1), start_dim=0, end_dim=1) + return lstm_output, lstm_states + + def forward( + self, + obs: th.Tensor, + lstm_states: RNNStates, + episode_starts: th.Tensor, + deterministic: bool = False, + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, RNNStates]: + """ + Forward pass in all the networks (actor and critic) + + :param obs: Observation. Observation + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). + :param deterministic: Whether to sample or use deterministic actions + :return: action, value and log probability of the action + """ + # Preprocess the observation if needed + features = self.extract_features(obs) + # latent_pi, latent_vf = self.mlp_extractor(features) + latent_pi, lstm_states_pi = self._process_sequence(features, lstm_states.pi, episode_starts, self.lstm_actor) + if self.lstm_critic is not None: + latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states.vf, episode_starts, self.lstm_critic) + elif self.shared_lstm: + # Re-use LSTM features but do not backpropagate + latent_vf = latent_pi.detach() + lstm_states_vf = (lstm_states_pi[0].detach(), lstm_states_pi[1].detach()) + else: + # Critic only has a feedforward network + latent_vf = self.critic(features) + lstm_states_vf = lstm_states_pi + + latent_pi = self.mlp_extractor.forward_actor(latent_pi) + latent_vf = self.mlp_extractor.forward_critic(latent_vf) + + # Evaluate the values for the given observations + values = self.value_net(latent_vf) + distribution = self._get_action_dist_from_latent(latent_pi) + actions = distribution.get_actions(deterministic=deterministic) + log_prob = distribution.log_prob(actions) + return actions, values, log_prob, RNNStates(lstm_states_pi, lstm_states_vf) + + def get_distribution( + self, + obs: th.Tensor, + lstm_states: Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + ) -> Tuple[Distribution, Tuple[th.Tensor, ...]]: + """ + Get the current policy distribution given the observations. 
+ + :param obs: Observation. + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). + :return: the action distribution and new hidden states. + """ + features = self.extract_features(obs) + latent_pi, lstm_states = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor) + latent_pi = self.mlp_extractor.forward_actor(latent_pi) + return self._get_action_dist_from_latent(latent_pi), lstm_states + + def predict_values( + self, + obs: th.Tensor, + lstm_states: Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + ) -> th.Tensor: + """ + Get the estimated values according to the current policy given the observations. + + :param obs: Observation. + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). + :return: the estimated values. + """ + features = self.extract_features(obs) + if self.lstm_critic is not None: + latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states, episode_starts, self.lstm_critic) + elif self.shared_lstm: + # Use LSTM from the actor + latent_pi, _ = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor) + latent_vf = latent_pi.detach() + else: + latent_vf = self.critic(features) + + latent_vf = self.mlp_extractor.forward_critic(latent_vf) + return self.value_net(latent_vf) + + def evaluate_actions( + self, + obs: th.Tensor, + actions: th.Tensor, + lstm_states: RNNStates, + episode_starts: th.Tensor, + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Evaluate actions according to the current policy, + given the observations. + + :param obs: Observation. + :param actions: + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). + :return: estimated value, log likelihood of taking those actions + and entropy of the action distribution. + """ + # Preprocess the observation if needed + features = self.extract_features(obs) + latent_pi, _ = self._process_sequence(features, lstm_states.pi, episode_starts, self.lstm_actor) + + if self.lstm_critic is not None: + latent_vf, _ = self._process_sequence(features, lstm_states.vf, episode_starts, self.lstm_critic) + elif self.shared_lstm: + latent_vf = latent_pi.detach() + else: + latent_vf = self.critic(features) + + latent_pi = self.mlp_extractor.forward_actor(latent_pi) + latent_vf = self.mlp_extractor.forward_critic(latent_vf) + + distribution = self._get_action_dist_from_latent(latent_pi) + log_prob = distribution.log_prob(actions) + values = self.value_net(latent_vf) + return values, log_prob, distribution.entropy() + + def _predict( + self, + observation: th.Tensor, + lstm_states: Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + deterministic: bool = False, + ) -> Tuple[th.Tensor, Tuple[th.Tensor, ...]]: + """ + Get the action according to the policy for a given observation. + + :param observation: + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). 
+ :param deterministic: Whether to use stochastic or deterministic actions + :return: Taken action according to the policy and hidden states of the RNN + """ + distribution, lstm_states = self.get_distribution(observation, lstm_states, episode_starts) + return distribution.get_actions(deterministic=deterministic), lstm_states + + def predict( + self, + observation: Union[np.ndarray, Dict[str, np.ndarray]], + state: Optional[Tuple[np.ndarray, ...]] = None, + episode_start: Optional[np.ndarray] = None, + deterministic: bool = False, + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + """ + Get the policy action from an observation (and optional hidden state). + Includes sugar-coating to handle different observations (e.g. normalizing images). + + :param observation: the input observation + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). + :param deterministic: Whether or not to return deterministic actions. + :return: the model's action and the next hidden state + (used in recurrent policies) + """ + # Switch to eval mode (this affects batch norm / dropout) + self.set_training_mode(False) + + observation, vectorized_env = self.obs_to_tensor(observation) + + if isinstance(observation, dict): + n_envs = observation[list(observation.keys())[0]].shape[0] + else: + n_envs = observation.shape[0] + # state : (n_layers, n_envs, dim) + if state is None: + # Initialize hidden states to zeros + state = np.concatenate([np.zeros(self.lstm_hidden_state_shape) for _ in range(n_envs)], axis=1) + state = (state, state) + + if episode_start is None: + episode_start = np.array([False for _ in range(n_envs)]) + + with th.no_grad(): + # Convert to PyTorch tensors + states = th.tensor(state[0]).float().to(self.device), th.tensor(state[1]).float().to(self.device) + episode_starts = th.tensor(episode_start).float().to(self.device) + actions, states = self._predict( + observation, lstm_states=states, episode_starts=episode_starts, deterministic=deterministic + ) + states = (states[0].cpu().numpy(), states[1].cpu().numpy()) + + # Convert to numpy + actions = actions.cpu().numpy() + + if isinstance(self.action_space, gym.spaces.Box): + if self.squash_output: + # Rescale to proper domain when using squashing + actions = self.unscale_action(actions) + else: + # Actions could be on arbitrary scale, so clip the actions to avoid + # out of bound error (e.g. if sampling from a Gaussian distribution) + actions = np.clip(actions, self.action_space.low, self.action_space.high) + + # Remove batch dimension if needed + if not vectorized_env: + actions = actions[0] + + return actions, states + + +class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy): + """ + CNN recurrent policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. 
+ :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param sde_net_arch: Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + :param lstm_hidden_size: Number of hidden units for each LSTM layer. + :param n_lstm_layers: Number of LSTM layers. + :param shared_lstm: Whether the LSTM is shared between the actor and the critic. + By default, only the actor has a recurrent network. + :param enable_critic_lstm: Use a seperate LSTM for the critic. + :param lstm_kwargs: Additional keyword arguments to pass the the LSTM + constructor. + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + lstm_hidden_size: int = 256, + n_lstm_layers: int = 1, + enable_critic_lstm: bool = True, + lstm_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + sde_net_arch, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + lstm_hidden_size, + n_lstm_layers, + enable_critic_lstm, + lstm_kwargs, + ) + + +class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy): + """ + MultiInputActorClass policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. 
+ :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param sde_net_arch: Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + :param lstm_hidden_size: Number of hidden units for each LSTM layer. + :param n_lstm_layers: Number of LSTM layers. + :param shared_lstm: Whether the LSTM is shared between the actor and the critic. + By default, only the actor has a recurrent network. + :param enable_critic_lstm: Use a seperate LSTM for the critic. + :param lstm_kwargs: Additional keyword arguments to pass the the LSTM + constructor. + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + lstm_hidden_size: int = 256, + n_lstm_layers: int = 1, + enable_critic_lstm: bool = True, + lstm_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + sde_net_arch, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + lstm_hidden_size, + n_lstm_layers, + enable_critic_lstm, + lstm_kwargs, + ) diff --git a/sb3_contrib/common/recurrent/type_aliases.py b/sb3_contrib/common/recurrent/type_aliases.py new file mode 100644 index 00000000..1ae9a087 --- /dev/null +++ b/sb3_contrib/common/recurrent/type_aliases.py @@ -0,0 +1,33 @@ +from typing import NamedTuple, Tuple + +import torch as th +from stable_baselines3.common.type_aliases import TensorDict + + +class RNNStates(NamedTuple): + pi: Tuple[th.Tensor, ...] + vf: Tuple[th.Tensor, ...] 
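+
+    # Illustration only: ``pi`` and ``vf`` each hold a ``(hidden_state, cell_state)``
+    # pair, one for the actor LSTM and one for the critic LSTM, following the
+    # ``(h_0, c_0)`` convention of ``torch.nn.LSTM``. During rollout collection the
+    # tensors have shape (n_layers, n_envs, hidden_size); in the padded training
+    # minibatches, n_envs is replaced by the number of sequences (n_seq).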
+ + +class RecurrentRolloutBufferSamples(NamedTuple): + observations: th.Tensor + actions: th.Tensor + old_values: th.Tensor + old_log_prob: th.Tensor + advantages: th.Tensor + returns: th.Tensor + lstm_states: RNNStates + episode_starts: th.Tensor + mask: th.Tensor + + +class RecurrentDictRolloutBufferSamples(RecurrentRolloutBufferSamples): + observations: TensorDict + actions: th.Tensor + old_values: th.Tensor + old_log_prob: th.Tensor + advantages: th.Tensor + returns: th.Tensor + lstm_states: RNNStates + episode_starts: th.Tensor + mask: th.Tensor diff --git a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py index 2f70ce0c..f78022a0 100644 --- a/sb3_contrib/ppo_mask/ppo_mask.py +++ b/sb3_contrib/ppo_mask/ppo_mask.py @@ -392,14 +392,17 @@ def predict( action_masks: Optional[np.ndarray] = None, ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: """ - Get the model's action(s) from an observation. + Get the policy action from an observation (and optional hidden state). + Includes sugar-coating to handle different observations (e.g. normalizing images). :param observation: the input observation - :param state: The last states (can be None, used in recurrent policies) - :param mask: The last masks (can be None, used in recurrent policies) + :param state: The last hidden states (can be None, used in recurrent policies) + :param episode_start: The last masks (can be None, used in recurrent policies) + this correspond to beginning of episodes, + where the hidden states of the RNN must be reset. :param deterministic: Whether or not to return deterministic actions. - :param action_masks: Action masks to apply to the action distribution. - :return: the model's action and the next state (used in recurrent policies) + :return: the model's action and the next hidden state + (used in recurrent policies) """ return self.policy.predict(observation, state, episode_start, deterministic, action_masks=action_masks) diff --git a/sb3_contrib/ppo_recurrent/__init__.py b/sb3_contrib/ppo_recurrent/__init__.py new file mode 100644 index 00000000..3fb5436e --- /dev/null +++ b/sb3_contrib/ppo_recurrent/__init__.py @@ -0,0 +1,2 @@ +from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy +from sb3_contrib.ppo_recurrent.ppo_recurrent import RecurrentPPO diff --git a/sb3_contrib/ppo_recurrent/policies.py b/sb3_contrib/ppo_recurrent/policies.py new file mode 100644 index 00000000..d9b37458 --- /dev/null +++ b/sb3_contrib/ppo_recurrent/policies.py @@ -0,0 +1,9 @@ +from sb3_contrib.common.recurrent.policies import ( + RecurrentActorCriticCnnPolicy, + RecurrentActorCriticPolicy, + RecurrentMultiInputActorCriticPolicy, +) + +MlpLstmPolicy = RecurrentActorCriticPolicy +CnnLstmPolicy = RecurrentActorCriticCnnPolicy +MultiInputLstmPolicy = RecurrentMultiInputActorCriticPolicy diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py new file mode 100644 index 00000000..f0920f9d --- /dev/null +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -0,0 +1,530 @@ +import time +from copy import deepcopy +from typing import Any, Dict, Optional, Tuple, Type, Union + +import gym +import numpy as np +import torch as th +from gym import spaces +from stable_baselines3.common.buffers import RolloutBuffer +from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm +from stable_baselines3.common.policies import BasePolicy +from 
stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean +from stable_baselines3.common.vec_env import VecEnv + +from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer +from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy +from sb3_contrib.common.recurrent.type_aliases import RNNStates +from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy + + +class RecurrentPPO(OnPolicyAlgorithm): + """ + Proximal Policy Optimization algorithm (PPO) (clip version) + with support for recurrent policies (LSTM). + + Based on the original Stable Baselines 3 implementation. + + Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html + + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: The learning rate, it can be a function + of the current progress remaining (from 1 to 0) + :param n_steps: The number of steps to run for each environment per update + (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel) + :param batch_size: Minibatch size + :param n_epochs: Number of epoch when optimizing the surrogate loss + :param gamma: Discount factor + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + :param clip_range: Clipping parameter, it can be a function of the current progress + remaining (from 1 to 0). + :param clip_range_vf: Clipping parameter for the value function, + it can be a function of the current progress remaining (from 1 to 0). + This is a parameter specific to the OpenAI implementation. If None is passed (default), + no clipping will be done on the value function. + IMPORTANT: this clipping depends on the reward scaling. + :param normalize_advantage: Whether to normalize or not the advantage + :param ent_coef: Entropy coefficient for the loss calculation + :param vf_coef: Value function coefficient for the loss calculation + :param max_grad_norm: The maximum value for the gradient clipping + :param target_kl: Limit the KL divergence between updates, + because the clipping is not enough to prevent large update + see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213) + By default, there is no limit on the kl div. + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param create_eval_env: Whether to create a second environment that will be + used for evaluating the agent periodically. (Only available when passing string for the environment) + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. + Setting it to auto, the code will be run on the GPU if possible. 
+ :param _init_setup_model: Whether or not to build the network at the creation of the instance + """ + + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpLstmPolicy": MlpLstmPolicy, + "CnnLstmPolicy": CnnLstmPolicy, + "MultiInputLstmPolicy": MultiInputLstmPolicy, + } + + def __init__( + self, + policy: Union[str, Type[RecurrentActorCriticPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule] = 3e-4, + n_steps: int = 128, + batch_size: Optional[int] = 128, + n_epochs: int = 10, + gamma: float = 0.99, + gae_lambda: float = 0.95, + clip_range: Union[float, Schedule] = 0.2, + clip_range_vf: Union[None, float, Schedule] = None, + normalize_advantage: bool = True, + ent_coef: float = 0.0, + vf_coef: float = 0.5, + max_grad_norm: float = 0.5, + use_sde: bool = False, + sde_sample_freq: int = -1, + target_kl: Optional[float] = None, + tensorboard_log: Optional[str] = None, + create_eval_env: bool = False, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + super().__init__( + policy, + env, + learning_rate=learning_rate, + n_steps=n_steps, + gamma=gamma, + gae_lambda=gae_lambda, + ent_coef=ent_coef, + vf_coef=vf_coef, + max_grad_norm=max_grad_norm, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, + tensorboard_log=tensorboard_log, + create_eval_env=create_eval_env, + policy_kwargs=policy_kwargs, + verbose=verbose, + seed=seed, + device=device, + _init_setup_model=False, + supported_action_spaces=( + spaces.Box, + spaces.Discrete, + spaces.MultiDiscrete, + spaces.MultiBinary, + ), + ) + + self.batch_size = batch_size + self.n_epochs = n_epochs + self.clip_range = clip_range + self.clip_range_vf = clip_range_vf + self.normalize_advantage = normalize_advantage + self.target_kl = target_kl + self._last_lstm_states = None + + if _init_setup_model: + self._setup_model() + + def _setup_model(self) -> None: + self._setup_lr_schedule() + self.set_random_seed(self.seed) + + buffer_cls = ( + RecurrentDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RecurrentRolloutBuffer + ) + + self.policy = self.policy_class( + self.observation_space, + self.action_space, + self.lr_schedule, + use_sde=self.use_sde, + **self.policy_kwargs, # pytype:disable=not-instantiable + ) + self.policy = self.policy.to(self.device) + + # We assume that LSTM for the actor and the critic + # have the same architecture + lstm = self.policy.lstm_actor + + if not isinstance(self.policy, RecurrentActorCriticPolicy): + raise ValueError("Policy must subclass RecurrentActorCriticPolicy") + + single_hidden_state_shape = (lstm.num_layers, self.n_envs, lstm.hidden_size) + # hidden and cell states for actor and critic + self._last_lstm_states = RNNStates( + ( + th.zeros(single_hidden_state_shape).to(self.device), + th.zeros(single_hidden_state_shape).to(self.device), + ), + ( + th.zeros(single_hidden_state_shape).to(self.device), + th.zeros(single_hidden_state_shape).to(self.device), + ), + ) + + hidden_state_buffer_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) + + self.rollout_buffer = buffer_cls( + self.n_steps, + self.observation_space, + self.action_space, + hidden_state_buffer_shape, + self.device, + gamma=self.gamma, + gae_lambda=self.gae_lambda, + n_envs=self.n_envs, + ) + + # Initialize schedules for policy/value clipping + self.clip_range = get_schedule_fn(self.clip_range) + if self.clip_range_vf is not None: + if 
isinstance(self.clip_range_vf, (float, int)):
+                assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, pass `None` to deactivate vf clipping"
+
+            self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
+
+    def _setup_learn(
+        self,
+        total_timesteps: int,
+        eval_env: Optional[GymEnv],
+        callback: MaybeCallback = None,
+        eval_freq: int = 10000,
+        n_eval_episodes: int = 5,
+        log_path: Optional[str] = None,
+        reset_num_timesteps: bool = True,
+        tb_log_name: str = "RecurrentPPO",
+    ) -> Tuple[int, BaseCallback]:
+        """
+        Initialize different variables needed for training.
+
+        :param total_timesteps: The total number of samples (env steps) to train on
+        :param eval_env: Environment to use for evaluation.
+        :param callback: Callback(s) called at every step with state of the algorithm.
+        :param eval_freq: How many steps between evaluations
+        :param n_eval_episodes: How many episodes to play per evaluation
+        :param log_path: Path to a folder where the evaluations will be saved
+        :param reset_num_timesteps: Whether or not to reset the ``num_timesteps`` attribute
+        :param tb_log_name: the name of the run for tensorboard log
+        :return: Total timesteps and callback(s)
+        """
+
+        total_timesteps, callback = super()._setup_learn(
+            total_timesteps,
+            eval_env,
+            callback,
+            eval_freq,
+            n_eval_episodes,
+            log_path,
+            reset_num_timesteps,
+            tb_log_name,
+        )
+        return total_timesteps, callback
+
+    def collect_rollouts(
+        self,
+        env: VecEnv,
+        callback: BaseCallback,
+        rollout_buffer: RolloutBuffer,
+        n_rollout_steps: int,
+    ) -> bool:
+        """
+        Collect experiences using the current policy and fill a ``RolloutBuffer``.
+        The term rollout here refers to the model-free notion and should not
+        be used with the concept of rollout used in model-based RL or planning.
+
+        :param env: The training environment
+        :param callback: Callback that will be called at each step
+            (and at the beginning and end of the rollout)
+        :param rollout_buffer: Buffer to fill with rollouts
+        :param n_rollout_steps: Number of experiences to collect per environment
+        :return: True if function returned with at least `n_rollout_steps`
+            collected, False if callback terminated rollout prematurely.
+        """
+        assert isinstance(
+            rollout_buffer, (RecurrentRolloutBuffer, RecurrentDictRolloutBuffer)
+        ), f"{rollout_buffer} doesn't support recurrent policy"
+
+        assert self._last_obs is not None, "No previous observation was provided"
+        # Switch to eval mode (this affects batch norm / dropout)
+        self.policy.set_training_mode(False)
+
+        n_steps = 0
+        rollout_buffer.reset()
+        # Sample new weights for the state dependent exploration
+        if self.use_sde:
+            self.policy.reset_noise(env.num_envs)
+
+        callback.on_rollout_start()
+
+        lstm_states = deepcopy(self._last_lstm_states)
+
+        while n_steps < n_rollout_steps:
+            if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
+                # Sample a new noise matrix
+                self.policy.reset_noise(env.num_envs)
+
+            with th.no_grad():
+                # Convert to pytorch tensor or to TensorDict
+                obs_tensor = obs_as_tensor(self._last_obs, self.device)
+                episode_starts = th.tensor(self._last_episode_starts).float().to(self.device)
+                actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts)
+
+            actions = actions.cpu().numpy()
+
+            # Rescale and perform action
+            clipped_actions = actions
+            # Clip the actions to avoid out of bound error
+            if isinstance(self.action_space, gym.spaces.Box):
+                clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
+
+            new_obs, rewards, dones, infos = env.step(clipped_actions)
+
+            self.num_timesteps += env.num_envs
+
+            # Give access to local variables
+            callback.update_locals(locals())
+            if callback.on_step() is False:
+                return False
+
+            self._update_info_buffer(infos)
+            n_steps += 1
+
+            if isinstance(self.action_space, gym.spaces.Discrete):
+                # Reshape in case of discrete action
+                actions = actions.reshape(-1, 1)
+
+            # Handle timeout by bootstrapping with value function
+            # see GitHub issue #633
+            for idx, done_ in enumerate(dones):
+                if (
+                    done_
+                    and infos[idx].get("terminal_observation") is not None
+                    and infos[idx].get("TimeLimit.truncated", False)
+                ):
+                    terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
+                    with th.no_grad():
+                        terminal_lstm_state = (
+                            lstm_states.vf[0][:, idx : idx + 1, :],
+                            lstm_states.vf[1][:, idx : idx + 1, :],
+                        )
+                        # terminal_lstm_state = None
+                        episode_starts = th.tensor([False]).float().to(self.device)
+                        terminal_value = self.policy.predict_values(terminal_obs, terminal_lstm_state, episode_starts)[0]
+                    rewards[idx] += self.gamma * terminal_value
+
+            rollout_buffer.add(
+                self._last_obs,
+                actions,
+                rewards,
+                self._last_episode_starts,
+                values,
+                log_probs,
+                lstm_states=self._last_lstm_states,
+            )
+
+            self._last_obs = new_obs
+            self._last_episode_starts = dones
+            self._last_lstm_states = lstm_states
+
+        with th.no_grad():
+            # Compute value for the last timestep
+            episode_starts = th.tensor(dones).float().to(self.device)
+            values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), lstm_states.vf, episode_starts)
+
+        rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
+
+        callback.on_rollout_end()
+
+        return True
+
+    def train(self) -> None:
+        """
+        Update policy using the currently gathered rollout buffer.
+ """ + # Switch to train mode (this affects batch norm / dropout) + self.policy.set_training_mode(True) + # Update optimizer learning rate + self._update_learning_rate(self.policy.optimizer) + # Compute current clip range + clip_range = self.clip_range(self._current_progress_remaining) + # Optional: clip range for the value function + if self.clip_range_vf is not None: + clip_range_vf = self.clip_range_vf(self._current_progress_remaining) + + entropy_losses = [] + pg_losses, value_losses = [], [] + clip_fractions = [] + + continue_training = True + + # train for n_epochs epochs + for epoch in range(self.n_epochs): + approx_kl_divs = [] + # Do a complete pass on the rollout buffer + for rollout_data in self.rollout_buffer.get(self.batch_size): + actions = rollout_data.actions + if isinstance(self.action_space, spaces.Discrete): + # Convert discrete action from float to long + actions = rollout_data.actions.long().flatten() + + # Re-sample the noise matrix because the log_std has changed + if self.use_sde: + self.policy.reset_noise(self.batch_size) + + values, log_prob, entropy = self.policy.evaluate_actions( + rollout_data.observations, + actions, + rollout_data.lstm_states, + rollout_data.episode_starts, + ) + + values = values.flatten() + # Normalize advantage + advantages = rollout_data.advantages + if self.normalize_advantage: + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + + # ratio between old and new policy, should be one at the first iteration + ratio = th.exp(log_prob - rollout_data.old_log_prob) + + # clipped surrogate loss + policy_loss_1 = advantages * ratio + policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range) + # Mask padded sequences + policy_loss_1 = policy_loss_1 * rollout_data.mask + policy_loss_2 = policy_loss_2 * rollout_data.mask + policy_loss = -th.min(policy_loss_1, policy_loss_2).mean() + + # Logging + pg_losses.append(policy_loss.item()) + clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item() + clip_fractions.append(clip_fraction) + + if self.clip_range_vf is None: + # No clipping + values_pred = values + else: + # Clip the different between old and new value + # NOTE: this depends on the reward scaling + values_pred = rollout_data.old_values + th.clamp( + values - rollout_data.old_values, -clip_range_vf, clip_range_vf + ) + # Value loss using the TD(gae_lambda) target + # Mask padded sequences + value_loss = th.mean(((rollout_data.returns - values_pred) * rollout_data.mask) ** 2) + + value_losses.append(value_loss.item()) + + # Entropy loss favor exploration + if entropy is None: + # Approximate entropy when no analytical form + entropy_loss = -th.mean(-(log_prob * rollout_data.mask)) + else: + entropy_loss = -th.mean(entropy * rollout_data.mask) + + entropy_losses.append(entropy_loss.item()) + + loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + + # Calculate approximate form of reverse KL Divergence for early stopping + # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 + # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 + # and Schulman blog: http://joschu.net/blog/kl-approx.html + with th.no_grad(): + log_ratio = log_prob - rollout_data.old_log_prob + approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy() + approx_kl_divs.append(approx_kl_div) + + if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: + continue_training = False + if self.verbose >= 1: + print(f"Early 
stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}") + break + + # Optimization step + self.policy.optimizer.zero_grad() + loss.backward() + # Clip grad norm + th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) + self.policy.optimizer.step() + + if not continue_training: + break + + self._n_updates += self.n_epochs + explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) + + # Logs + self.logger.record("train/entropy_loss", np.mean(entropy_losses)) + self.logger.record("train/policy_gradient_loss", np.mean(pg_losses)) + self.logger.record("train/value_loss", np.mean(value_losses)) + self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) + self.logger.record("train/clip_fraction", np.mean(clip_fractions)) + self.logger.record("train/loss", loss.item()) + self.logger.record("train/explained_variance", explained_var) + if hasattr(self.policy, "log_std"): + self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) + + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/clip_range", clip_range) + if self.clip_range_vf is not None: + self.logger.record("train/clip_range_vf", clip_range_vf) + + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 1, + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + tb_log_name: str = "RecurrentPPO", + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + ) -> "RecurrentPPO": + iteration = 0 + + total_timesteps, callback = self._setup_learn( + total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name + ) + + callback.on_training_start(locals(), globals()) + + while self.num_timesteps < total_timesteps: + + continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps) + + if continue_training is False: + break + + iteration += 1 + self._update_current_progress_remaining(self.num_timesteps, total_timesteps) + + # Display training infos + if log_interval is not None and iteration % log_interval == 0: + fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time)) + self.logger.record("time/iterations", iteration, exclude="tensorboard") + if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: + self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) + self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) + self.logger.record("time/fps", fps) + self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard") + self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") + self.logger.dump(step=self.num_timesteps) + + self.train() + + callback.on_training_end() + + return self diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index 1d756441..ac076349 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -218,13 +218,16 @@ def predict( deterministic: bool = False, ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: """ - Overrides the base_class predict function to include epsilon-greedy exploration. + Get the policy action from an observation (and optional hidden state). + Includes sugar-coating to handle different observations (e.g. 
normalizing images). :param observation: the input observation - :param state: The last states (can be None, used in recurrent policies) - :param mask: The last masks (can be None, used in recurrent policies) + :param state: The last hidden states (can be None, used in recurrent policies) + :param episode_start: The last masks (can be None, used in recurrent policies) + this correspond to beginning of episodes, + where the hidden states of the RNN must be reset. :param deterministic: Whether or not to return deterministic actions. - :return: the model's action and the next state + :return: the model's action and the next hidden state (used in recurrent policies) """ if not deterministic and np.random.rand() < self.exploration_rate: diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index e39732bd..511e75b2 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.5.1a7 +1.5.1a8 diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh index a3795075..45931f0b 100755 --- a/scripts/run_tests.sh +++ b/scripts/run_tests.sh @@ -1,2 +1,2 @@ #!/bin/bash -python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v --color=yes +python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v --color=yes -m "not slow" diff --git a/setup.cfg b/setup.cfg index 11009480..4a0cf75b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,6 +13,8 @@ filterwarnings = ignore:Parameters to load are deprecated.:DeprecationWarning ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning ignore::UserWarning:gym +markers = + slow: marks tests as slow (deselect with '-m "not slow"') [pytype] inputs = sb3_contrib @@ -24,6 +26,7 @@ per-file-ignores = ./sb3_contrib/__init__.py:F401 ./sb3_contrib/ars/__init__.py:F401 ./sb3_contrib/ppo_mask/__init__.py:F401 + ./sb3_contrib/ppo_recurrent/__init__.py:F401 ./sb3_contrib/qrdqn/__init__.py:F401 ./sb3_contrib/tqc/__init__.py:F401 ./sb3_contrib/trpo/__init__.py:F401 diff --git a/tests/test_deterministic.py b/tests/test_deterministic.py index 1ba7283e..458d3f06 100644 --- a/tests/test_deterministic.py +++ b/tests/test_deterministic.py @@ -3,7 +3,7 @@ from stable_baselines3.common.noise import NormalActionNoise from stable_baselines3.common.vec_env import VecNormalize -from sb3_contrib import ARS, QRDQN, TQC +from sb3_contrib import ARS, QRDQN, TQC, RecurrentPPO from sb3_contrib.common.vec_env import AsyncEval N_STEPS_TRAINING = 500 @@ -11,7 +11,7 @@ ARS_MULTI = "ars_multi" -@pytest.mark.parametrize("algo", [ARS, QRDQN, TQC, ARS_MULTI]) +@pytest.mark.parametrize("algo", [ARS, QRDQN, TQC, ARS_MULTI, RecurrentPPO]) def test_deterministic_training_common(algo): results = [[], []] rewards = [[], []] @@ -32,9 +32,12 @@ def test_deterministic_training_common(algo): kwargs.update({"learning_starts": 100, "target_update_interval": 100}) elif algo == ARS: kwargs.update({"n_delta": 2}) - + elif algo == RecurrentPPO: + kwargs.update({"policy_kwargs": dict(net_arch=[], enable_critic_lstm=True, lstm_hidden_size=8)}) + kwargs.update({"n_steps": 50, "n_epochs": 4}) + policy_str = "MlpLstmPolicy" if algo == RecurrentPPO else "MlpPolicy" for i in range(2): - model = algo("MlpPolicy", env_id, seed=SEED, **kwargs) + model = algo(policy_str, env_id, seed=SEED, **kwargs) learn_kwargs = {"total_timesteps": N_STEPS_TRAINING} if ars_multi: @@ -46,9 +49,11 @@ def test_deterministic_training_common(algo): model.learn(**learn_kwargs) env = model.get_env() obs = env.reset() + states = None + 
episode_start = None for _ in range(100): - action, _ = model.predict(obs, deterministic=False) - obs, reward, _, _ = env.step(action) + action, states = model.predict(obs, state=states, episode_start=episode_start, deterministic=False) + obs, reward, episode_start, _ = env.step(action) results[i].append(action) rewards[i].append(reward) assert sum(results[0]) == sum(results[1]), results diff --git a/tests/test_lstm.py b/tests/test_lstm.py new file mode 100644 index 00000000..f0ba3e6a --- /dev/null +++ b/tests/test_lstm.py @@ -0,0 +1,186 @@ +import gym +import numpy as np +import pytest +from gym import spaces +from gym.envs.classic_control import CartPoleEnv +from gym.wrappers.time_limit import TimeLimit +from stable_baselines3.common.callbacks import EvalCallback +from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.envs import FakeImageEnv +from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.vec_env import VecNormalize + +from sb3_contrib import RecurrentPPO + + +class ToDictWrapper(gym.Wrapper): + """ + Simple wrapper to test MultInputPolicy on Dict obs. + """ + + def __init__(self, env): + super().__init__(env) + self.observation_space = gym.spaces.Dict({"obs": self.env.observation_space}) + + def reset(self): + return {"obs": self.env.reset()} + + def step(self, action): + obs, reward, done, infos = self.env.step(action) + return {"obs": obs}, reward, done, infos + + +class CartPoleNoVelEnv(CartPoleEnv): + """Variant of CartPoleEnv with velocity information removed. This task requires memory to solve.""" + + def __init__(self): + super().__init__() + high = np.array( + [ + self.x_threshold * 2, + self.theta_threshold_radians * 2, + ] + ) + self.observation_space = spaces.Box(-high, high, dtype=np.float32) + + @staticmethod + def _pos_obs(full_obs): + xpos, _xvel, thetapos, _thetavel = full_obs + return xpos, thetapos + + def reset(self): + full_obs = super().reset() + return CartPoleNoVelEnv._pos_obs(full_obs) + + def step(self, action): + full_obs, rew, done, info = super().step(action) + return CartPoleNoVelEnv._pos_obs(full_obs), rew, done, info + + +def test_cnn(): + model = RecurrentPPO( + "CnnLstmPolicy", + FakeImageEnv(screen_height=40, screen_width=40, n_channels=3), + n_steps=16, + seed=0, + policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)), + ) + + model.learn(total_timesteps=32) + + +@pytest.mark.parametrize( + "policy_kwargs", + [ + {}, + dict(shared_lstm=True, enable_critic_lstm=False), + dict( + enable_critic_lstm=True, + lstm_hidden_size=4, + lstm_kwargs=dict(dropout=0.5), + ), + dict( + enable_critic_lstm=False, + lstm_hidden_size=4, + lstm_kwargs=dict(dropout=0.5), + ), + ], +) +def test_policy_kwargs(policy_kwargs): + model = RecurrentPPO( + "MlpLstmPolicy", + "CartPole-v1", + n_steps=16, + seed=0, + policy_kwargs=policy_kwargs, + ) + + model.learn(total_timesteps=32) + + +def test_check(): + policy_kwargs = dict(shared_lstm=True, enable_critic_lstm=True) + with pytest.raises(AssertionError): + RecurrentPPO( + "MlpLstmPolicy", + "CartPole-v1", + n_steps=16, + seed=0, + policy_kwargs=policy_kwargs, + ) + + +@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"]) +def test_run(env): + model = RecurrentPPO( + "MlpLstmPolicy", + env, + n_steps=16, + seed=0, + create_eval_env=True, + ) + + model.learn(total_timesteps=32, eval_freq=16) + + +def test_run_sde(): + model = RecurrentPPO( + "MlpLstmPolicy", + "Pendulum-v1", + n_steps=16, + seed=0, + create_eval_env=True, + 
sde_sample_freq=4, + use_sde=True, + clip_range_vf=0.1, + ) + + model.learn(total_timesteps=200, eval_freq=150) + + +def test_dict_obs(): + env = make_vec_env("CartPole-v1", n_envs=1, wrapper_class=ToDictWrapper) + model = RecurrentPPO("MultiInputLstmPolicy", env, n_steps=32).learn(64) + evaluate_policy(model, env, warn=False) + + +@pytest.mark.slow +def test_ppo_lstm_performance(): + # env = make_vec_env("CartPole-v1", n_envs=16) + def make_env(): + env = CartPoleNoVelEnv() + env = TimeLimit(env, max_episode_steps=500) + return env + + env = VecNormalize(make_vec_env(make_env, n_envs=8)) + + eval_callback = EvalCallback( + VecNormalize(make_vec_env(make_env, n_envs=4), training=False, norm_reward=False), + n_eval_episodes=20, + eval_freq=5000 // env.num_envs, + ) + + model = RecurrentPPO( + "MlpLstmPolicy", + env, + n_steps=128, + learning_rate=0.0007, + verbose=1, + batch_size=256, + seed=1, + n_epochs=10, + max_grad_norm=1, + gae_lambda=0.98, + policy_kwargs=dict( + net_arch=[dict(vf=[64])], + lstm_hidden_size=64, + ortho_init=False, + enable_critic_lstm=True, + ), + ) + + model.learn(total_timesteps=50_000, callback=eval_callback) + # Maximum episode reward is 500. + # In CartPole-v1, a non-recurrent policy can easily get >= 450. + # In CartPoleNoVelEnv, a non-recurrent policy doesn't get more than ~50. + evaluate_policy(model, env, reward_threshold=450)
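
The updated predict() docstrings in this changeset ask callers to pass episode_start so the LSTM hidden states are reset at episode boundaries, and collect_rollouts() forwards the same signal to policy.forward(). As a rough sketch of that masking idea (illustration only, not code from this changeset; shapes and variable names are assumptions), per-environment states can be zeroed with the inverse of the start mask:

    import torch as th

    # Assumed shapes for illustration: (n_lstm_layers, n_envs, hidden_size)
    n_layers, n_envs, hidden_size = 1, 4, 8
    hidden_state = th.randn(n_layers, n_envs, hidden_size)
    cell_state = th.randn(n_layers, n_envs, hidden_size)

    # 1.0 where that env just started a new episode, 0.0 otherwise
    episode_starts = th.tensor([1.0, 0.0, 0.0, 1.0])

    # Zero the states of the envs that just reset, keep the others untouched
    keep = (1.0 - episode_starts).view(1, -1, 1)
    hidden_state = hidden_state * keep
    cell_state = cell_state * keep

This is also why the updated test_deterministic.py feeds the dones returned by env.step() back into model.predict() as episode_start.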