From b5ce09105aa8df2701e85d30f0584380b2ee19f8 Mon Sep 17 00:00:00 2001
From: Antonin RAFFIN
Date: Wed, 17 Jul 2024 09:21:33 +0200
Subject: [PATCH] Add multi-env support

---
 sbx/common/prioritized_replay_buffer.py | 89 +++++++++++--------------
 sbx/per_dqn/per_dqn.py                  | 36 +++++++---
 2 files changed, 64 insertions(+), 61 deletions(-)

diff --git a/sbx/common/prioritized_replay_buffer.py b/sbx/common/prioritized_replay_buffer.py
index 05de787..6216aa3 100644
--- a/sbx/common/prioritized_replay_buffer.py
+++ b/sbx/common/prioritized_replay_buffer.py
@@ -14,7 +14,6 @@ from gymnasium import spaces
 from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.type_aliases import ReplayBufferSamples
-from stable_baselines3.common.utils import get_linear_fn
 from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
 
@@ -41,6 +40,8 @@ def __init__(self, capacity: int, reduce_op: Callable, neutral_element: float) -
        """
        assert capacity > 0 and capacity & (capacity - 1) == 0, f"Capacity must be positive and a power of 2, not {capacity}"
        self._capacity = capacity
+        # Index 1 is the root (index 0 is unused), leaf nodes are in [capacity, 2 * capacity - 1].
+        # For each parent node i, the left child is at index 2 * i, the right child at 2 * i + 1
        self._values = np.full(2 * capacity, neutral_element)
        self._reduce_op = reduce_op
        self.neutral_element = neutral_element
@@ -97,8 +98,10 @@ def __setitem__(self, idx: np.ndarray, val: np.ndarray) -> None:
        :param idx: index of the value to be updated
        :param val: new value
        """
+        # assert np.all((0 <= idx) & (idx < self._capacity)), f"Trying to set item outside capacity: {idx}"
        # Indices of the leafs
        indices = idx + self._capacity
+        # Update the leaf nodes and then the related nodes
        self._values[indices] = val
        if isinstance(indices, int):
            indices = np.array([indices])
@@ -153,8 +156,7 @@ def find_prefixsum_idx(self, prefixsum: np.ndarray) -> np.ndarray:
        if isinstance(prefixsum, float):
            prefixsum = np.array([prefixsum])
        assert 0 <= np.min(prefixsum)
-        assert np.max(prefixsum) <= self.sum() + 1e-5
-        assert isinstance(prefixsum[0], float)
+        # assert np.max(prefixsum) <= self.sum() + 1e-5
 
        indices = np.ones(len(prefixsum), dtype=int)
        should_continue = np.ones(len(prefixsum), dtype=bool)
@@ -227,8 +229,6 @@ def __init__(
        device: Union[th.device, str] = "auto",
        n_envs: int = 1,
        alpha: float = 0.5,
-        beta: float = 0.4,
-        final_beta: float = 1.0,
        optimize_memory_usage: bool = False,
        min_priority: float = 1e-6,
    ):
@@ -238,7 +238,7 @@
        assert optimize_memory_usage is False, "PrioritizedReplayBuffer doesn't support optimize_memory_usage=True"
 
-        # TODO: add support for multi env
-        assert n_envs == 1, "PrioritizedReplayBuffer doesn't support n_envs > 1"
+        # assert n_envs == 1, "PrioritizedReplayBuffer doesn't support n_envs > 1"
 
        # Find the next power of 2 for the buffer size
-        power_of_two = int(np.ceil(np.log2(buffer_size)))
+        # One leaf per (buffer_idx, env_idx) pair, so the tree needs buffer_size * n_envs leaves
+        power_of_two = int(np.ceil(np.log2(buffer_size * n_envs)))
@@ -249,26 +249,12 @@
        self._alpha = alpha
 
-        # Track the training progress remaining (from 1 to 0)
-        # this is used to update beta
-        self._current_progress_remaining = 1.0
-
-        # TODO: move beta schedule to the DQN algorithm
-        self._inital_beta = beta
-        self._final_beta = final_beta
-        self.beta_schedule = get_linear_fn(
-            self._inital_beta,
-            self._final_beta,
-            end_fraction=1.0,
-        )
-
        self._sum_tree = SumSegmentTree(tree_capacity)
        self._min_tree = MinSegmentTree(tree_capacity)
-
-    @property
-    def beta(self) -> float:
-        # Linear schedule
-        return self.beta_schedule(self._current_progress_remaining)
+        # Flatten the indices from the buffer to store them in the sum tree
+        # Replay buffer: (idx, env_idx)
+        # Sum tree: idx * self.n_envs + env_idx
+        self.env_offsets = np.arange(self.n_envs)
 
    def add(
        self,
@@ -289,14 +275,15 @@
        :param done: Whether the episode was finished after the transition to be stored.
        :param infos: Eventual information given by the environment.
        """
-        # store transition index with maximum priority in sum tree
-        self._sum_tree[self.pos] = self._max_priority**self._alpha
-        self._min_tree[self.pos] = self._max_priority**self._alpha
+        # Store transition index with maximum priority in sum tree
+        self._sum_tree[self.pos * self.n_envs + self.env_offsets] = self._max_priority**self._alpha
+        self._min_tree[self.pos * self.n_envs + self.env_offsets] = self._max_priority**self._alpha
 
-        # store transition in the buffer
+        # Store transition in the buffer
        super().add(obs, next_obs, action, reward, done, infos)
 
-    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
+    def sample(self, batch_size: int, beta: float, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        """
        Sample elements from the prioritized replay buffer.
 
        :param batch_size:
+        :param beta: Importance sampling exponent (0: no correction, 1: full correction)
@@ -305,20 +291,24 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB
            to normalize the observations/rewards when sampling
        :return: a batch of sampled experiences from the buffer.
        """
-        assert self.buffer_size >= batch_size, "The buffer contains less samples than the batch size requires."
+        # assert self.buffer_size >= batch_size, "The buffer contains fewer samples than the batch size requires."
 
        # priorities = np.zeros((batch_size, 1))
        # sample_indices = np.zeros(batch_size, dtype=np.uint32)
 
        # TODO: check how things are sampled in the original implementation
-        sample_indices = self._sample_proportional(batch_size)
-        leaf_nodes_indices = sample_indices
+        leaf_nodes_indices = self._sample_proportional(batch_size)
+        # Convert the leaf nodes indices to buffer indices
+        # Replay buffer: (idx, env_idx)
+        # Sum tree: idx * self.n_envs + env_idx
+        buffer_indices = leaf_nodes_indices // self.n_envs
+        env_indices = leaf_nodes_indices % self.n_envs
 
        # probability of sampling transition i as P(i) = p_i^alpha / \sum_{k} p_k^alpha
        # where p_i > 0 is the priority of transition i.
        # probs = priorities / self.tree.total_sum
-        probabilities = self._sum_tree[sample_indices] / self._sum_tree.sum()
+        probabilities = self._sum_tree[leaf_nodes_indices] / self._sum_tree.sum()
 
        # Importance sampling weights.
        # All weights w_i were scaled so that max_i w_i = 1.
@@ -326,20 +316,21 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB
        # min_probability = self._min_tree.min() / self._sum_tree.sum()
        # max_weight = (min_probability * self.size()) ** (-self.beta)
        # weights = (probabilities * self.size()) ** (-self.beta) / max_weight
-        weights = (probabilities * self.size()) ** (-self.beta)
+        weights = (probabilities * self.size()) ** (-beta)
        weights = weights / weights.max()
 
-        # TODO: add proper support for multi env
        # env_indices = np.random.randint(0, high=self.n_envs, size=(batch_size,))
-        env_indices = np.zeros(batch_size, dtype=np.uint32)
-        next_obs = self._normalize_obs(self.next_observations[sample_indices, env_indices, :], env)
+        # env_indices = np.zeros(batch_size, dtype=np.uint32)
+        next_obs = self._normalize_obs(self.next_observations[buffer_indices, env_indices, :], env)
 
        batch = (
-            self._normalize_obs(self.observations[sample_indices, env_indices, :], env),
-            self.actions[sample_indices, env_indices, :],
+            self._normalize_obs(self.observations[buffer_indices, env_indices, :], env),
+            self.actions[buffer_indices, env_indices, :],
            next_obs,
-            self.dones[sample_indices],
-            self.rewards[sample_indices],
+            # Only use dones that are not due to timeouts
+            # deactivated by default (timeouts is initialized as an array of False)
+            (self.dones[buffer_indices, env_indices] * (1 - self.timeouts[buffer_indices, env_indices])).reshape(-1, 1),
+            self._normalize_reward(self.rewards[buffer_indices, env_indices].reshape(-1, 1), env),
            weights,
        )
        return ReplayBufferSamples(*tuple(map(self.to_torch, batch)), leaf_nodes_indices)  # type: ignore[arg-type,call-arg]
@@ -358,28 +349,24 @@
        return self._sum_tree.find_prefixsum_idx(priorities_sum)
 
    # def update_priorities(self, indices: np.ndarray, priorities: np.ndarray) -> None:
-    def update_priorities(self, indices: np.ndarray, priorities: np.ndarray, progress_remaining: float) -> None:
+    def update_priorities(self, leaf_nodes_indices: np.ndarray, priorities: np.ndarray) -> None:
        """
        Update priorities of sampled transitions.
 
        :param leaf_nodes_indices: Indices of the sampled transitions.
-        :param td_errors: New priorities, td error in the case of
+        :param priorities: New priorities, the absolute TD error in the case of the
            proportional prioritized replay buffer.
""" - # TODO: move beta to the DQN algorithm - # Update beta schedule - self._current_progress_remaining = progress_remaining - # TODO: double check that all samples are updated # priorities = np.abs(td_errors) + self.min_priority priorities += self._min_priority # assert len(indices) == len(priorities) assert np.min(priorities) > 0 - assert np.min(indices) >= 0 - assert np.max(indices) < self.buffer_size + assert np.min(leaf_nodes_indices) >= 0 + assert np.max(leaf_nodes_indices) < self.buffer_size # TODO: check if we need to add the min_priority here # priorities = (np.abs(td_errors) + self.min_priority) ** self.alpha - self._sum_tree[indices] = priorities**self._alpha - self._min_tree[indices] = priorities**self._alpha + self._sum_tree[leaf_nodes_indices] = priorities**self._alpha + self._min_tree[leaf_nodes_indices] = priorities**self._alpha # Update max priority for new samples self._max_priority = max(self._max_priority, np.max(priorities)) diff --git a/sbx/per_dqn/per_dqn.py b/sbx/per_dqn/per_dqn.py index 7f5ae9a..9aa63a4 100644 --- a/sbx/per_dqn/per_dqn.py +++ b/sbx/per_dqn/per_dqn.py @@ -5,6 +5,7 @@ import numpy as np import optax from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.utils import get_linear_fn from sbx.common.prioritized_replay_buffer import PrioritizedReplayBuffer from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState @@ -39,6 +40,8 @@ def __init__( exploration_fraction: float = 0.1, exploration_initial_eps: float = 1.0, exploration_final_eps: float = 0.05, + initial_beta: float = 0.4, + final_beta: float = 1.0, optimize_memory_usage: bool = False, # Note: unused but to match SB3 API # max_grad_norm: float = 10, train_freq: Union[int, Tuple[int, str]] = 4, @@ -77,6 +80,19 @@ def __init__( _init_setup_model=_init_setup_model, ) + self._inital_beta = initial_beta + self._final_beta = final_beta + self.beta_schedule = get_linear_fn( + self._inital_beta, + self._final_beta, + end_fraction=1.0, + ) + + @property + def beta(self) -> float: + # Linear schedule + return self.beta_schedule(self._current_progress_remaining) + def learn( self, total_timesteps: int, @@ -97,7 +113,7 @@ def learn( def train(self, batch_size: int, gradient_steps: int) -> None: # Sample all at once for efficiency (so we can jit the for loop) - data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env) + data = self.replay_buffer.sample(batch_size * gradient_steps, self.beta, env=self._vec_normalize_env) # Convert to numpy data = ReplayBufferSamplesNp( data.observations.numpy(), @@ -121,7 +137,7 @@ def train(self, batch_size: int, gradient_steps: int) -> None: "info": { "critic_loss": jnp.array([0.0]), "qf_mean_value": jnp.array([0.0]), - "td_error": jnp.zeros_like(data.rewards), + "priorities": jnp.zeros_like(data.rewards), }, } @@ -137,12 +153,12 @@ def train(self, batch_size: int, gradient_steps: int) -> None: self.policy.qf_state = update_carry["qf_state"] qf_loss_value = update_carry["info"]["critic_loss"] qf_mean_value = update_carry["info"]["qf_mean_value"] / gradient_steps - td_error = update_carry["info"]["td_error"] + priorities = update_carry["info"]["priorities"] # Update priorities, they will be proportional to the td error # Note: compared to the original implementation, we update # the priorities after all the gradient steps - self.replay_buffer.update_priorities(data.leaf_nodes_indices, td_error, self._current_progress_remaining) + 
 
        self._n_updates += gradient_steps
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
@@ -179,24 +195,24 @@ def weighted_huber_loss(params):
            # Retrieve the q-values for the actions from the replay buffer
            current_q_values = jnp.take_along_axis(current_q_values, replay_actions, axis=1)
            # TD error in absolute value, to update priorities
-            td_error = jnp.abs(current_q_values - target_q_values)
+            priorities = jnp.abs(current_q_values - target_q_values)
            # Weighted Huber loss using importance sampling weights
            loss = (sampling_weights * optax.huber_loss(current_q_values, target_q_values)).mean()
-            return loss, (current_q_values.mean(), td_error.flatten())
+            return loss, (current_q_values.mean(), priorities.flatten())
 
-        (qf_loss_value, (qf_mean_value, td_error)), grads = jax.value_and_grad(weighted_huber_loss, has_aux=True)(
+        (qf_loss_value, (qf_mean_value, priorities)), grads = jax.value_and_grad(weighted_huber_loss, has_aux=True)(
            qf_state.params
        )
        qf_state = qf_state.apply_gradients(grads=grads)
 
-        return qf_state, (qf_loss_value, qf_mean_value, td_error)
+        return qf_state, (qf_loss_value, qf_mean_value, priorities)
 
    @staticmethod
    @jax.jit
    def _train(carry, indices):
        data = carry["data"]
-        qf_state, (qf_loss_value, qf_mean_value, td_error) = PERDQN.update_qnetwork(
+        qf_state, (qf_loss_value, qf_mean_value, priorities) = PERDQN.update_qnetwork(
            carry["gamma"],
            carry["qf_state"],
            observations=data.observations[indices],
@@ -210,6 +226,6 @@
        carry["qf_state"] = qf_state
        carry["info"]["critic_loss"] += qf_loss_value
        carry["info"]["qf_mean_value"] += qf_mean_value
-        carry["info"]["td_error"] = carry["info"]["td_error"].at[indices].set(td_error)
+        carry["info"]["priorities"] = carry["info"]["priorities"].at[indices].set(priorities)
        return carry, None
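
A quick way to see the index mapping at the heart of this patch: add() writes one sum-tree leaf per (pos, env_idx) pair, and sample() inverts the flattening with integer division and modulo. The following minimal standalone NumPy sketch (sizes and variable names are illustrative, not taken from sbx) walks through the round trip:

    import numpy as np

    n_envs, buffer_size = 4, 8  # illustrative sizes
    env_offsets = np.arange(n_envs)

    # Writing, as in PrioritizedReplayBuffer.add(): one leaf per (pos, env_idx) pair
    pos = 5
    leaf_indices = pos * n_envs + env_offsets  # -> [20, 21, 22, 23]

    # Reading, as in sample(): invert the flattening
    buffer_indices = leaf_indices // n_envs    # -> [5, 5, 5, 5]
    env_indices = leaf_indices % n_envs        # -> [0, 1, 2, 3]

    assert np.all(buffer_indices == pos)
    assert np.all(env_indices == env_offsets)

Because the flattened index ranges over buffer_size * n_envs values, the segment trees need that many leaves (rounded up to the next power of 2), which is what the capacity computation and the bounds check in update_priorities account for.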
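The beta schedule moved into PERDQN is a plain linear interpolation over training progress, driving the importance-sampling correction w_i = (N * P(i))^(-beta), normalized so that max_i w_i = 1. The sketch below is a rough self-contained approximation of how get_linear_fn behaves (reimplemented here for illustration, not the actual stable-baselines3 code):

    import numpy as np

    def linear_schedule(initial_beta: float, final_beta: float, end_fraction: float = 1.0):
        # Takes progress_remaining (1 at the start of training, 0 at the end)
        # and interpolates linearly between initial_beta and final_beta
        def schedule(progress_remaining: float) -> float:
            progress = 1.0 - progress_remaining
            if progress > end_fraction:
                return final_beta
            return initial_beta + progress * (final_beta - initial_beta) / end_fraction
        return schedule

    beta_fn = linear_schedule(0.4, 1.0)
    assert beta_fn(1.0) == 0.4 and beta_fn(0.0) == 1.0  # start / end of training

    # Importance sampling weights as computed in sample():
    # probabilities P(i) come from the sum tree, weights correct the sampling bias
    probabilities = np.array([0.1, 0.05, 0.3])  # illustrative values
    size, beta = 100, beta_fn(0.5)
    weights = (probabilities * size) ** (-beta)
    weights = weights / weights.max()  # so that max_i w_i = 1

With end_fraction=1.0 as in the patch, beta anneals from initial_beta at the start of training to final_beta at the very end, so the bias correction becomes full exactly when the Q-values are expected to have converged.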