Feat/rainbow #50 (Draft)
Wants to merge 16 commits into base: master
2 changes: 1 addition & 1 deletion Makefile
@@ -14,7 +14,7 @@ lint:
# see https://www.flake8rules.com/
ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
# exit-zero treats all errors as warnings.
- ruff check ${LINT_PATHS} --exit-zero
+ ruff check ${LINT_PATHS} --exit-zero --output-format=concise

format:
# Sort imports
2 changes: 2 additions & 0 deletions sbx/__init__.py
@@ -3,6 +3,7 @@
from sbx.crossq import CrossQ
from sbx.ddpg import DDPG
from sbx.dqn import DQN
from sbx.per_dqn import PERDQN
from sbx.ppo import PPO
from sbx.sac import SAC
from sbx.td3 import TD3
@@ -26,6 +27,7 @@ def DroQ(*args, **kwargs):
"CrossQ",
"DDPG",
"DQN",
"PERDQN",
"PPO",
"SAC",
"TD3",
395 changes: 395 additions & 0 deletions sbx/common/prioritized_replay_buffer.py

Large diffs are not rendered by default.
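The new buffer itself is not shown in the diff. For orientation, here is a minimal, self-contained sketch of the usual proportional-PER building block, a sum tree over transition priorities; the class and method names below are illustrative and do not necessarily match sbx/common/prioritized_replay_buffer.py:

# Illustrative sketch only -- not the code added in this PR.
import numpy as np


class SumTree:
    """Binary sum tree: leaves hold per-transition priorities, internal nodes hold partial sums."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.nodes = np.zeros(2 * capacity - 1, dtype=np.float64)

    def update(self, leaf_index: int, priority: float) -> None:
        tree_index = leaf_index + self.capacity - 1
        delta = priority - self.nodes[tree_index]
        # Propagate the change up to the root so the total mass stays consistent.
        while True:
            self.nodes[tree_index] += delta
            if tree_index == 0:
                break
            tree_index = (tree_index - 1) // 2

    def retrieve(self, value: float) -> int:
        """Return the leaf index whose cumulative-priority interval contains `value`."""
        index = 0
        while index < self.capacity - 1:  # descend until a leaf is reached
            left = 2 * index + 1
            if value <= self.nodes[left]:
                index = left
            else:
                value -= self.nodes[left]
                index = left + 1
        return index - (self.capacity - 1)

    @property
    def total(self) -> float:
        return float(self.nodes[0])


def sample_leaf_indices(tree: SumTree, batch_size: int, rng: np.random.Generator) -> np.ndarray:
    # Stratified proportional sampling: one uniform draw per equal slice of the total mass.
    segment = tree.total / batch_size
    values = segment * np.arange(batch_size) + rng.uniform(0.0, segment, size=batch_size)
    return np.array([tree.retrieve(v) for v in values])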

4 changes: 3 additions & 1 deletion sbx/common/type_aliases.py
@@ -1,4 +1,4 @@
- from typing import NamedTuple
+ from typing import NamedTuple, Optional, Union

import flax
import numpy as np
@@ -19,3 +19,5 @@ class ReplayBufferSamplesNp(NamedTuple):
next_observations: np.ndarray
dones: np.ndarray
rewards: np.ndarray
weights: Union[np.ndarray, float] = 1.0
leaf_nodes_indices: Optional[np.ndarray] = None
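Both new fields carry defaults, so existing uniform-replay call sites that build the tuple positionally keep working, while a prioritized buffer can attach importance-sampling weights and the sampled leaf indices. A hedged usage sketch (shapes and values below are illustrative):

import numpy as np
from sbx.common.type_aliases import ReplayBufferSamplesNp

batch = 4
obs = np.zeros((batch, 3), dtype=np.float32)
actions = np.zeros((batch, 1), dtype=np.int64)
flat = np.zeros(batch, dtype=np.float32)

# Uniform replay: the defaults apply, every transition gets weight 1.0 and no indices.
uniform_samples = ReplayBufferSamplesNp(obs, actions, obs, flat, flat)

# Prioritized replay: weights and leaf indices travel with the batch so the
# buffer's priorities can be updated once the TD errors are known.
per_samples = ReplayBufferSamplesNp(
    obs, actions, obs, flat, flat,
    weights=np.ones(batch, dtype=np.float32),
    leaf_nodes_indices=np.arange(batch),
)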
5 changes: 4 additions & 1 deletion sbx/crossq/crossq.py
@@ -211,7 +211,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
self.policy.actor_state,
self.ent_coef_state,
self.key,
- (actor_loss_value, qf_loss_value, ent_coef_value),
+ (actor_loss_value, qf_loss_value, ent_coef_loss),
) = self._train(
self.gamma,
self.target_entropy,
@@ -224,11 +224,14 @@
self.ent_coef_state,
self.key,
)
ent_coef_value = self.ent_coef_state.apply_fn({"params": self.ent_coef_state.params})
self._n_updates += gradient_steps
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
self.logger.record("train/actor_loss", actor_loss_value.item())
self.logger.record("train/critic_loss", qf_loss_value.item())
self.logger.record("train/ent_coef", ent_coef_value.item())
if isinstance(self.ent_coef, EntropyCoef):
self.logger.record("train/ent_coef_loss", ent_coef_loss.item())

@staticmethod
@jax.jit
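For context on the logging change above: the jitted train loop now returns the entropy-coefficient loss, and the current coefficient value is recovered by applying the coefficient's train state. Learnable coefficients in SAC-style code are typically a tiny log-parameterized Flax module along the following lines (a generic sketch, not necessarily SBX's EntropyCoef):

import flax.linen as nn
import jax.numpy as jnp


class LearnableEntropyCoef(nn.Module):
    """Log-parameterized entropy coefficient; calling it returns the positive value."""

    initial_value: float = 1.0

    @nn.compact
    def __call__(self) -> jnp.ndarray:
        log_ent_coef = self.param(
            "log_ent_coef",
            lambda key: jnp.full((), jnp.log(self.initial_value)),
        )
        return jnp.exp(log_ent_coef)

With such a module, applying the train state's params at any time yields the current coefficient, which is why the value no longer needs to be threaded through the jitted update.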
14 changes: 11 additions & 3 deletions sbx/dqn/dqn.py
@@ -6,6 +6,7 @@
import jax.numpy as jnp
import numpy as np
import optax
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn

@@ -41,6 +42,8 @@ def __init__(
# max_grad_norm: float = 10,
train_freq: Union[int, Tuple[int, str]] = 4,
gradient_steps: int = 1,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
@@ -59,6 +62,8 @@ def __init__(
gamma=gamma,
train_freq=train_freq,
gradient_steps=gradient_steps,
replay_buffer_class=replay_buffer_class,
replay_buffer_kwargs=replay_buffer_kwargs,
policy_kwargs=policy_kwargs,
tensorboard_log=tensorboard_log,
verbose=verbose,
@@ -130,11 +135,12 @@ def learn(
progress_bar=progress_bar,
)

- def train(self, batch_size, gradient_steps):
+ def train(self, gradient_steps: int, batch_size: int) -> None:
assert self.replay_buffer is not None
# Sample all at once for efficiency (so we can jit the for loop)
data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env)
# Convert to numpy
- data = ReplayBufferSamplesNp(
+ data = ReplayBufferSamplesNp(  # type: ignore[assignment]
data.observations.numpy(),
# Convert to int64
data.actions.long().numpy(),
@@ -222,7 +228,9 @@ def _on_step(self) -> None:
This method is called in ``collect_rollouts()`` after each step in the environment.
"""
self._n_calls += 1
- if self._n_calls % self.target_update_interval == 0:
+ # Account for multiple environments
+ # each call to step() corresponds to n_envs transitions
+ if self._n_calls % max(self.target_update_interval // self.n_envs, 1) == 0:
self.policy.qf_state = DQN.soft_update(self.tau, self.policy.qf_state)

self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
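A quick worked example of the rescaled interval (numbers chosen for illustration): with target_update_interval=1000 and 4 parallel environments, the target network now syncs every 250 calls to _on_step(), which still corresponds to roughly 1000 environment transitions.

target_update_interval = 1_000  # desired interval, measured in environment transitions
n_envs = 4                      # each call to step() collects n_envs transitions

calls_between_syncs = max(target_update_interval // n_envs, 1)
assert calls_between_syncs == 250
assert calls_between_syncs * n_envs == 1_000  # ~1000 transitions between target syncs, as before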
20 changes: 14 additions & 6 deletions sbx/dqn/policies.py
@@ -69,6 +69,7 @@ def __init__(
normalize_images: bool = True,
optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
max_grad_norm: float = 10.0,
):
super().__init__(
observation_space,
@@ -85,6 +86,7 @@
else:
self.n_units = 256
self.activation_fn = activation_fn
self.max_grad_norm = max_grad_norm

def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array:
key, qf_key = jax.random.split(key, 2)
@@ -101,9 +103,12 @@ def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array:
apply_fn=self.qf.apply,
params=self.qf.init({"params": qf_key}, obs),
target_params=self.qf.init({"params": qf_key}, obs),
tx=self.optimizer_class(
learning_rate=lr_schedule(1), # type: ignore[call-arg]
**self.optimizer_kwargs,
tx=optax.chain(
optax.clip_by_global_norm(self.max_grad_norm),
self.optimizer_class(
learning_rate=lr_schedule(1), # type: ignore[call-arg]
**self.optimizer_kwargs,
),
),
)

@@ -140,9 +145,12 @@ def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array:
apply_fn=self.qf.apply,
params=self.qf.init({"params": qf_key}, obs),
target_params=self.qf.init({"params": qf_key}, obs),
tx=self.optimizer_class(
learning_rate=lr_schedule(1), # type: ignore[call-arg]
**self.optimizer_kwargs,
tx=optax.chain(
optax.clip_by_global_norm(self.max_grad_norm),
self.optimizer_class(
learning_rate=lr_schedule(1), # type: ignore[call-arg]
**self.optimizer_kwargs,
),
),
)
self.qf.apply = jax.jit(self.qf.apply) # type: ignore[method-assign]
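The optimizer change above wraps the user-supplied optimizer in optax.chain so that global-norm gradient clipping runs before the optimizer update. The same pattern in isolation (learning rate and norm cap are illustrative):

import jax.numpy as jnp
import optax

max_grad_norm = 10.0
tx = optax.chain(
    optax.clip_by_global_norm(max_grad_norm),  # rescale gradients whose global norm exceeds the cap
    optax.adam(learning_rate=1e-3),
)

params = {"w": jnp.ones((2, 2))}
opt_state = tx.init(params)
grads = {"w": 100.0 * jnp.ones((2, 2))}  # deliberately oversized gradient
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)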
3 changes: 3 additions & 0 deletions sbx/per_dqn/__init__.py
@@ -0,0 +1,3 @@
from sbx.per_dqn.per_dqn import PERDQN

__all__ = ["PERDQN"]