Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RecurrentPPO: 9x speedup - whole sequence batching #118

Draft
wants to merge 28 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
6467c79
added whole sequence batching functionality to PPORecurrent
b-vm Nov 12, 2022
ff8cb9d
added masking and fixed some bugs
b-vm Nov 14, 2022
cfc0f70
implemented for non dict obs + fixed bugs + added basic script showin…
b-vm Nov 23, 2022
8b60954
bug fix episode starts after first update
b-vm Nov 28, 2022
59f4a7a
updated testing script
b-vm Nov 28, 2022
9196049
fixed NaNs due to supersmall batch sizesoccurring in edge case by dro…
b-vm Nov 28, 2022
a74be36
Merge branch 'master' into sequence_batching
araffin Nov 28, 2022
0f5a0f7
Merge branch 'master' into sequence_batching
araffin Dec 16, 2022
f445d4d
Merge branch 'master' of https://github.com/Stable-Baselines-Team/sta…
b-vm Jan 8, 2023
b83d5f8
Merge branch 'Stable-Baselines-Team-master' into sequence_batching
b-vm Jan 8, 2023
d3af84d
refactoring, made code simpler.
b-vm Jan 8, 2023
18ace01
improved indexing to sample all sequences
b-vm Jan 8, 2023
de092ba
integrated whole sequence train function with existing
b-vm Mar 1, 2023
d66bcaa
Merge pull request #2 from b-vm/master
b-vm Mar 1, 2023
5dc8bc9
improvement of isntance checking
b-vm Mar 1, 2023
9a268be
Merge branch 'master' into sequence_batching
araffin Apr 3, 2023
e543b31
Reformat and simplify
araffin Apr 3, 2023
194f758
bug fix flatten extractor instance check
b-vm Apr 12, 2023
8f18c9c
simplified if statement
b-vm Apr 12, 2023
5e6f371
Merge branch 'master' into sequence_batching
araffin Apr 27, 2023
ad0d3ed
Update comment
araffin Apr 27, 2023
b43e9b5
Re-add drop last, was causing NaN
araffin Apr 27, 2023
ef37cc7
Fix NaN
araffin Apr 27, 2023
a2382c8
fixed bug: append correct index, and only once.
b-vm May 29, 2023
2837be5
Merge pull request #3 from Stable-Baselines-Team/master
b-vm Jun 29, 2023
79fbd3f
Merge branch 'master' into sequence_batching
araffin Sep 3, 2023
7edb731
Merge branch 'master' into sequence_batching
araffin Oct 6, 2023
9898e27
Merge branch 'master' into sequence_batching
araffin Jan 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,5 @@ src
*.prof

MUJOCO_LOG.TXT

temp/
188 changes: 187 additions & 1 deletion sb3_contrib/common/recurrent/buffers.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from functools import partial
from typing import Callable, Generator, Optional, Tuple, Union
from typing import Callable, Generator, List, Optional, Tuple, Union

import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.vec_env import VecNormalize
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler

from sb3_contrib.common.recurrent.type_aliases import (
RecurrentDictRolloutBufferSamples,
RecurrentDictRolloutBufferSequenceSamples,
RecurrentRolloutBufferSamples,
RecurrentRolloutBufferSequenceSamples,
RNNStates,
)

Expand Down Expand Up @@ -94,6 +98,30 @@ def create_sequencers(
return seq_start_indices, local_pad, local_pad_and_flatten


def create_sequence_slicer(
episode_start_indices: np.ndarray, device: Union[th.device, str]
) -> Callable[[np.ndarray, List[str]], th.Tensor]:
def create_sequence_minibatch(tensor: np.ndarray, seq_indices: List[str]) -> th.Tensor:
"""
Create minibatch of whole sequence.

:param tensor: Tensor that will be sliced (e.g. observations, rewards)
:param seq_indices: Sequences to be used.
:return: (max_sequence_length, batch_size=n_seq, features_size)
"""
return pad_sequence(
[
th.tensor(
tensor[episode_start_indices[i] : episode_start_indices[i + 1]],
device=device,
)
for i in seq_indices
]
)

return create_sequence_minibatch


class RecurrentRolloutBuffer(RolloutBuffer):
"""
Rollout buffer that also stores the LSTM cell and hidden states.
Expand Down Expand Up @@ -382,3 +410,161 @@ def _get_samples(
episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]),
mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])),
)


class RecurrentSequenceRolloutBuffer(RecurrentRolloutBuffer):
"""
Sequence Rollout buffer used in on-policy algorithms like A2C/PPO.
Overrides the RecurrentRolloutBuffer to yield 3d batches of whole sequences

:param buffer_size: Max number of element in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param hidden_state_shape: Shape of the buffer that will collect lstm states
:param device: PyTorch device
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
Equivalent to classic advantage when set to 1.
:param gamma: Discount factor
:param n_envs: Number of parallel environments
"""

def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
hidden_state_shape: Tuple[int, int, int, int],
device: Union[th.device, str] = "auto",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
):
self.hidden_state_shape = hidden_state_shape
self.seq_start_indices, self.seq_end_indices = None, None
super().__init__(
buffer_size, observation_space, action_space, hidden_state_shape, device, gae_lambda, gamma, n_envs=n_envs
)

def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSequenceSamples, None, None]:
assert self.full, "Rollout buffer must be full before sampling from it"
# Prepare the data
if not self.generator_ready:
self.episode_starts[0, :] = 1
for tensor in [
"observations",
"actions",
"values",
"log_probs",
"advantages",
"returns",
"episode_starts",
]:
self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])

self.episode_start_indices = np.where(self.episode_starts == 1)[0]
self.generator_ready = True

random_indices = SubsetRandomSampler(range(len(self.episode_start_indices)))
# Do not drop last batch so we are sure we sample at least one sequence
# TODO: allow to change that parameter
batch_sampler = BatchSampler(random_indices, batch_size, drop_last=False)
# add a dummy index to make the code below simpler
episode_start_indices = np.concatenate([self.episode_start_indices, np.array([len(self.episode_starts)])])

create_minibatch = create_sequence_slicer(episode_start_indices, self.device)

# yields batches of whole sequences, shape: (max_sequence_length, batch_size=n_seq, features_size))
for indices in batch_sampler:
returns_batch = create_minibatch(self.returns, indices)
masks_batch = pad_sequence([th.ones_like(returns) for returns in th.swapaxes(returns_batch, 0, 1)])

yield RecurrentRolloutBufferSequenceSamples(
observations=create_minibatch(self.observations, indices),
actions=create_minibatch(self.actions, indices),
old_values=create_minibatch(self.values, indices),
old_log_prob=create_minibatch(self.log_probs, indices),
advantages=create_minibatch(self.advantages, indices),
returns=returns_batch,
mask=masks_batch,
)


class RecurrentSequenceDictRolloutBuffer(RecurrentDictRolloutBuffer):
"""
Sequence Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
Overrides the DictRecurrentRolloutBuffer to yield 3d batches of whole sequences

:param buffer_size: Max number of element in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param hidden_state_shape: Shape of the buffer that will collect lstm states
:param device: PyTorch device
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
Equivalent to classic advantage when set to 1.
:param gamma: Discount factor
:param n_envs: Number of parallel environments
"""

def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
hidden_state_shape: Tuple[int, int, int, int],
device: Union[th.device, str] = "auto",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
):
self.hidden_state_shape = hidden_state_shape
self.seq_start_indices, self.seq_end_indices = None, None
super().__init__(
buffer_size, observation_space, action_space, hidden_state_shape, device, gae_lambda, gamma, n_envs=n_envs
)

def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRolloutBufferSequenceSamples, None, None]:
assert self.full, "Rollout buffer must be full before sampling from it"
# Prepare the data
if not self.generator_ready:
self.episode_starts[0, :] = 1
for key, obs in self.observations.items():
self.observations[key] = self.swap_and_flatten(obs)

for tensor in [
"actions",
"values",
"log_probs",
"advantages",
"returns",
"episode_starts",
]:
self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])

self.episode_start_indices = np.where(self.episode_starts == 1)[0]
self.generator_ready = True

random_indices = SubsetRandomSampler(range(len(self.episode_start_indices)))
# drop last batch to prevent extremely small batches causing spurious updates
batch_sampler = BatchSampler(random_indices, batch_size, drop_last=True)
# add a dummy index to make the code below simpler
episode_start_indices = np.concatenate([self.episode_start_indices, np.array([len(self.episode_starts)])])

create_minibatch = create_sequence_slicer(episode_start_indices, self.device)

# yields batches of whole sequences, shape: (sequence_length, batch_size=n_seq, features_size)
for indices in batch_sampler:
obs_batch = {}
for key in self.observations:
obs_batch[key] = create_minibatch(self.observations[key], indices)
returns_batch = create_minibatch(self.returns, indices)
masks_batch = pad_sequence([th.ones_like(returns) for returns in th.swapaxes(returns_batch, 0, 1)])

yield RecurrentDictRolloutBufferSequenceSamples(
observations=obs_batch,
actions=create_minibatch(self.actions, indices),
old_values=create_minibatch(self.values, indices),
old_log_prob=create_minibatch(self.log_probs, indices),
advantages=create_minibatch(self.advantages, indices),
returns=returns_batch,
mask=masks_batch,
)
48 changes: 48 additions & 0 deletions sb3_contrib/common/recurrent/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,54 @@ def evaluate_actions(
values = self.value_net(latent_vf)
return values, log_prob, distribution.entropy()

def evaluate_actions_whole_sequence(
self,
obs: th.Tensor,
actions: th.Tensor,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
"""
Evaluate actions of batches of whole sequences according to the current policy,
given the observations.

:param obs: Observation.
:param actions:
:param lstm_states: The last hidden and memory states for the LSTM.
:param episode_starts: Whether the observations correspond to new episodes
or not (we reset the lstm states in that case).
:return: estimated value, log likelihood of taking those actions
and entropy of the action distribution.
"""
# Preprocess the observation if needed

# temporary fix to disable the flattening that stable_baselines3 feature extractors do by default
# flattening will turn the sequences in the batch into 1 long sequence without proper resetting of lstm hidden states
if self.features_extractor_class == FlattenExtractor:
features = obs
else:
features = self.extract_features(obs)
latent_pi, _ = self.lstm_actor(features)

if self.lstm_critic is not None:
latent_vf, _ = self.lstm_critic(features)
elif self.shared_lstm:
latent_vf = latent_pi.detach()
else:
latent_vf = self.critic(features)

latent_pi = self.mlp_extractor.forward_actor(latent_pi)
latent_vf = self.mlp_extractor.forward_critic(latent_vf)

values = self.value_net(latent_vf)

distribution = self._get_action_dist_from_latent(latent_pi)
log_prob = distribution.distribution.log_prob(actions).sum(dim=-1)
log_prob = log_prob.reshape((*log_prob.shape, 1))

entropy = distribution.distribution.entropy().sum(dim=-1)
entropy = entropy.reshape((*entropy.shape, 1))

return values, log_prob, entropy

def _predict(
self,
observation: th.Tensor,
Expand Down
20 changes: 20 additions & 0 deletions sb3_contrib/common/recurrent/type_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,23 @@ class RecurrentDictRolloutBufferSamples(NamedTuple):
lstm_states: RNNStates
episode_starts: th.Tensor
mask: th.Tensor


class RecurrentRolloutBufferSequenceSamples(NamedTuple):
observations: th.Tensor
actions: th.Tensor
old_values: th.Tensor
old_log_prob: th.Tensor
advantages: th.Tensor
returns: th.Tensor
mask: th.Tensor


class RecurrentDictRolloutBufferSequenceSamples(NamedTuple):
observations: TensorDict
actions: th.Tensor
old_values: th.Tensor
old_log_prob: th.Tensor
advantages: th.Tensor
returns: th.Tensor
mask: th.Tensor
Loading
Loading