From 655f4a3b953b3bed27982b2dcf4ed6695d7d37bc Mon Sep 17 00:00:00 2001 From: Paolo Date: Wed, 3 Apr 2024 11:49:01 +0200 Subject: [PATCH] Allow to pass custom activation function in `policy_kwargs` (#41) * Added act_fn optionality to td3 and sac * Updated makefile to new ruff version * Added test for policy_kwargs argument * Able to pass activation_fn to SAC model * Added the possibility to pass act_fn also for dqn * Reformatted dqn/policies.py according to black style * Add type hint to PPO like in other policies to be consistent * Added suggested adjustment * Update tests and version * Add support for TQC --------- Co-authored-by: Antonin Raffin --- Makefile | 8 +++--- sbx/common/policies.py | 51 ++++++++++++++++++++++++++++++++++++- sbx/dqn/policies.py | 13 +++++++--- sbx/ppo/policies.py | 6 ++--- sbx/sac/policies.py | 56 +++++----------------------------------- sbx/td3/policies.py | 58 +++++------------------------------------- sbx/tqc/policies.py | 11 +++++--- sbx/version.txt | 2 +- tests/test_run.py | 22 ++++++++++++++-- 9 files changed, 110 insertions(+), 117 deletions(-) diff --git a/Makefile b/Makefile index 2240fdc..0177d5a 100644 --- a/Makefile +++ b/Makefile @@ -12,19 +12,19 @@ type: mypy lint: # stop the build if there are Python syntax errors or undefined names # see https://www.flake8rules.com/ - ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full + ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full # exit-zero treats all errors as warnings. - ruff ${LINT_PATHS} --exit-zero + ruff check ${LINT_PATHS} --exit-zero format: # Sort imports - ruff --select I ${LINT_PATHS} --fix + ruff check --select I ${LINT_PATHS} --fix # Reformat using black black ${LINT_PATHS} check-codestyle: # Sort imports - ruff --select I ${LINT_PATHS} + ruff check --select I ${LINT_PATHS} # Reformat using black black --check ${LINT_PATHS} diff --git a/sbx/common/policies.py b/sbx/common/policies.py index dced3b5..4b6df5d 100644 --- a/sbx/common/policies.py +++ b/sbx/common/policies.py @@ -1,5 +1,5 @@ # import copy -from typing import Dict, Optional, Tuple, Union, no_type_check +from typing import Callable, Dict, Optional, Sequence, Tuple, Union, no_type_check import flax.linen as nn import jax @@ -120,3 +120,52 @@ def set_training_mode(self, mode: bool) -> None: # self.actor.set_training_mode(mode) # self.critic.set_training_mode(mode) self.training = mode + + +class ContinuousCritic(nn.Module): + net_arch: Sequence[int] + use_layer_norm: bool = False + dropout_rate: Optional[float] = None + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + + @nn.compact + def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray: + x = Flatten()(x) + x = jnp.concatenate([x, action], -1) + for n_units in self.net_arch: + x = nn.Dense(n_units)(x) + if self.dropout_rate is not None and self.dropout_rate > 0: + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False) + if self.use_layer_norm: + x = nn.LayerNorm()(x) + x = self.activation_fn(x) + x = nn.Dense(1)(x) + return x + + +class VectorCritic(nn.Module): + net_arch: Sequence[int] + use_layer_norm: bool = False + dropout_rate: Optional[float] = None + n_critics: int = 2 + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + + @nn.compact + def __call__(self, obs: jnp.ndarray, action: jnp.ndarray): + # Idea taken from https://github.com/perrin-isir/xpag + # Similar to https://github.com/tinkoff-ai/CORL for PyTorch + vmap_critic = nn.vmap( + ContinuousCritic, + variable_axes={"params": 0}, # parameters not shared between the critics + split_rngs={"params": True, "dropout": True}, # different initializations + in_axes=None, + out_axes=0, + axis_size=self.n_critics, + ) + q_values = vmap_critic( + use_layer_norm=self.use_layer_norm, + dropout_rate=self.dropout_rate, + net_arch=self.net_arch, + activation_fn=self.activation_fn, + )(obs, action) + return q_values diff --git a/sbx/dqn/policies.py b/sbx/dqn/policies.py index 1a4ed46..b9226ee 100644 --- a/sbx/dqn/policies.py +++ b/sbx/dqn/policies.py @@ -15,14 +15,15 @@ class QNetwork(nn.Module): n_actions: int n_units: int = 256 + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu @nn.compact def __call__(self, x: jnp.ndarray) -> jnp.ndarray: x = Flatten()(x) x = nn.Dense(self.n_units)(x) - x = nn.relu(x) + x = self.activation_fn(x) x = nn.Dense(self.n_units)(x) - x = nn.relu(x) + x = self.activation_fn(x) x = nn.Dense(self.n_actions)(x) return x @@ -36,6 +37,7 @@ def __init__( action_space: spaces.Discrete, lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, features_extractor_class=None, features_extractor_kwargs: Optional[Dict[str, Any]] = None, normalize_images: bool = True, @@ -56,13 +58,18 @@ def __init__( self.n_units = net_arch[0] else: self.n_units = 256 + self.activation_fn = activation_fn def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array: key, qf_key = jax.random.split(key, 2) obs = jnp.array([self.observation_space.sample()]) - self.qf = QNetwork(n_actions=int(self.action_space.n), n_units=self.n_units) + self.qf = QNetwork( + n_actions=int(self.action_space.n), + n_units=self.n_units, + activation_fn=self.activation_fn, + ) self.qf_state = RLTrainState.create( apply_fn=self.qf.apply, diff --git a/sbx/ppo/policies.py b/sbx/ppo/policies.py index 0f8d3a3..54915c8 100644 --- a/sbx/ppo/policies.py +++ b/sbx/ppo/policies.py @@ -20,7 +20,7 @@ class Critic(nn.Module): n_units: int = 256 - activation_fn: Callable = nn.tanh + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh @nn.compact def __call__(self, x: jnp.ndarray) -> jnp.ndarray: @@ -37,7 +37,7 @@ class Actor(nn.Module): action_dim: int n_units: int = 256 log_std_init: float = 0.0 - activation_fn: Callable = nn.tanh + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh # For Discrete, MultiDiscrete and MultiBinary actions num_discrete_choices: Optional[Union[int, Sequence[int]]] = None # For MultiDiscrete @@ -104,7 +104,7 @@ def __init__( net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, ortho_init: bool = False, log_std_init: float = 0.0, - activation_fn=nn.tanh, + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh, use_sde: bool = False, # Note: most gSDE parameters are not used # this is to keep API consistent with SB3 diff --git a/sbx/sac/policies.py b/sbx/sac/policies.py index b637aa0..f936f91 100644 --- a/sbx/sac/policies.py +++ b/sbx/sac/policies.py @@ -11,63 +11,18 @@ from stable_baselines3.common.type_aliases import Schedule from sbx.common.distributions import TanhTransformedDistribution -from sbx.common.policies import BaseJaxPolicy, Flatten +from sbx.common.policies import BaseJaxPolicy, Flatten, VectorCritic from sbx.common.type_aliases import RLTrainState tfd = tfp.distributions -class Critic(nn.Module): - net_arch: Sequence[int] - use_layer_norm: bool = False - dropout_rate: Optional[float] = None - - @nn.compact - def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray: - x = Flatten()(x) - x = jnp.concatenate([x, action], -1) - for n_units in self.net_arch: - x = nn.Dense(n_units)(x) - if self.dropout_rate is not None and self.dropout_rate > 0: - x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False) - if self.use_layer_norm: - x = nn.LayerNorm()(x) - x = nn.relu(x) - x = nn.Dense(1)(x) - return x - - -class VectorCritic(nn.Module): - net_arch: Sequence[int] - use_layer_norm: bool = False - dropout_rate: Optional[float] = None - n_critics: int = 2 - - @nn.compact - def __call__(self, obs: jnp.ndarray, action: jnp.ndarray): - # Idea taken from https://github.com/perrin-isir/xpag - # Similar to https://github.com/tinkoff-ai/CORL for PyTorch - vmap_critic = nn.vmap( - Critic, - variable_axes={"params": 0}, # parameters not shared between the critics - split_rngs={"params": True, "dropout": True}, # different initializations - in_axes=None, - out_axes=0, - axis_size=self.n_critics, - ) - q_values = vmap_critic( - use_layer_norm=self.use_layer_norm, - dropout_rate=self.dropout_rate, - net_arch=self.net_arch, - )(obs, action) - return q_values - - class Actor(nn.Module): net_arch: Sequence[int] action_dim: int log_std_min: float = -20 log_std_max: float = 2 + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu def get_std(self): # Make it work with gSDE @@ -78,7 +33,7 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def x = Flatten()(x) for n_units in self.net_arch: x = nn.Dense(n_units)(x) - x = nn.relu(x) + x = self.activation_fn(x) mean = nn.Dense(self.action_dim)(x) log_std = nn.Dense(self.action_dim)(x) log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max) @@ -99,7 +54,7 @@ def __init__( net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, dropout_rate: float = 0.0, layer_norm: bool = False, - # activation_fn: Type[nn.Module] = nn.ReLU, + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, use_sde: bool = False, # Note: most gSDE parameters are not used # this is to keep API consistent with SB3 @@ -135,6 +90,7 @@ def __init__( self.net_arch_pi = self.net_arch_qf = [256, 256] self.n_critics = n_critics self.use_sde = use_sde + self.activation_fn = activation_fn self.key = self.noise_key = jax.random.PRNGKey(0) @@ -154,6 +110,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) self.actor = Actor( action_dim=int(np.prod(self.action_space.shape)), net_arch=self.net_arch_pi, + activation_fn=self.activation_fn, ) # Hack to make gSDE work without modifying internal SB3 code self.actor.reset_noise = self.reset_noise @@ -172,6 +129,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) use_layer_norm=self.layer_norm, net_arch=self.net_arch_qf, n_critics=self.n_critics, + activation_fn=self.activation_fn, ) self.qf_state = RLTrainState.create( diff --git a/sbx/td3/policies.py b/sbx/td3/policies.py index e27709b..3592535 100644 --- a/sbx/td3/policies.py +++ b/sbx/td3/policies.py @@ -8,66 +8,21 @@ from gymnasium import spaces from stable_baselines3.common.type_aliases import Schedule -from sbx.common.policies import BaseJaxPolicy, Flatten +from sbx.common.policies import BaseJaxPolicy, Flatten, VectorCritic from sbx.common.type_aliases import RLTrainState -class Critic(nn.Module): - net_arch: Sequence[int] - use_layer_norm: bool = False - dropout_rate: Optional[float] = None - - @nn.compact - def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray: - x = Flatten()(x) - x = jnp.concatenate([x, action], -1) - for n_units in self.net_arch: - x = nn.Dense(n_units)(x) - if self.dropout_rate is not None and self.dropout_rate > 0: - x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False) - if self.use_layer_norm: - x = nn.LayerNorm()(x) - x = nn.relu(x) - x = nn.Dense(1)(x) - return x - - -class VectorCritic(nn.Module): - net_arch: Sequence[int] - use_layer_norm: bool = False - dropout_rate: Optional[float] = None - n_critics: int = 2 - - @nn.compact - def __call__(self, obs: jnp.ndarray, action: jnp.ndarray): - # Idea taken from https://github.com/perrin-isir/xpag - # Similar to https://github.com/tinkoff-ai/CORL for PyTorch - vmap_critic = nn.vmap( - Critic, - variable_axes={"params": 0}, # parameters not shared between the critics - split_rngs={"params": True, "dropout": True}, # different initializations - in_axes=None, - out_axes=0, - axis_size=self.n_critics, - ) - q_values = vmap_critic( - use_layer_norm=self.use_layer_norm, - dropout_rate=self.dropout_rate, - net_arch=self.net_arch, - )(obs, action) - return q_values - - class Actor(nn.Module): net_arch: Sequence[int] action_dim: int + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu @nn.compact def __call__(self, x: jnp.ndarray) -> jnp.ndarray: # type: ignore[name-defined] x = Flatten()(x) for n_units in self.net_arch: x = nn.Dense(n_units)(x) - x = nn.relu(x) + x = self.activation_fn(x) return nn.tanh(nn.Dense(self.action_dim)(x)) @@ -82,7 +37,7 @@ def __init__( net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, dropout_rate: float = 0.0, layer_norm: bool = False, - # activation_fn: Type[nn.Module] = nn.ReLU, + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, use_sde: bool = False, features_extractor_class=None, features_extractor_kwargs: Optional[Dict[str, Any]] = None, @@ -112,6 +67,7 @@ def __init__( else: self.net_arch_pi = self.net_arch_qf = [256, 256] self.n_critics = n_critics + self.activation_fn = activation_fn self.key = self.noise_key = jax.random.PRNGKey(0) @@ -127,8 +83,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) action = jnp.array([self.action_space.sample()]) self.actor = Actor( - action_dim=int(np.prod(self.action_space.shape)), - net_arch=self.net_arch_pi, + action_dim=int(np.prod(self.action_space.shape)), net_arch=self.net_arch_pi, activation_fn=self.activation_fn ) self.actor_state = RLTrainState.create( @@ -146,6 +101,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) use_layer_norm=self.layer_norm, net_arch=self.net_arch_qf, n_critics=self.n_critics, + activation_fn=self.activation_fn, ) self.qf_state = RLTrainState.create( diff --git a/sbx/tqc/policies.py b/sbx/tqc/policies.py index 66c6fb1..d075de5 100644 --- a/sbx/tqc/policies.py +++ b/sbx/tqc/policies.py @@ -22,6 +22,7 @@ class Critic(nn.Module): use_layer_norm: bool = False dropout_rate: Optional[float] = None n_quantiles: int = 25 + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu @nn.compact def __call__(self, x: jnp.ndarray, a: jnp.ndarray, training: bool = False) -> jnp.ndarray: @@ -33,7 +34,7 @@ def __call__(self, x: jnp.ndarray, a: jnp.ndarray, training: bool = False) -> jn x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False) if self.use_layer_norm: x = nn.LayerNorm()(x) - x = nn.relu(x) + x = self.activation_fn(x) x = nn.Dense(self.n_quantiles)(x) return x @@ -43,6 +44,7 @@ class Actor(nn.Module): action_dim: int log_std_min: float = -20 log_std_max: float = 2 + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu def get_std(self): # Make it work with gSDE @@ -53,7 +55,7 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def x = Flatten()(x) for n_units in self.net_arch: x = nn.Dense(n_units)(x) - x = nn.relu(x) + x = self.activation_fn(x) mean = nn.Dense(self.action_dim)(x) log_std = nn.Dense(self.action_dim)(x) log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max) @@ -76,7 +78,7 @@ def __init__( layer_norm: bool = False, top_quantiles_to_drop_per_net: int = 2, n_quantiles: int = 25, - # activation_fn: Type[nn.Module] = nn.ReLU, + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, use_sde: bool = False, # Note: most gSDE parameters are not used # this is to keep API consistent with SB3 @@ -118,6 +120,7 @@ def __init__( top_quantiles_to_drop_per_net = self.top_quantiles_to_drop_per_net self.n_target_quantiles = quantiles_total - top_quantiles_to_drop_per_net * self.n_critics self.use_sde = use_sde + self.activation_fn = activation_fn self.key = self.noise_key = jax.random.PRNGKey(0) @@ -137,6 +140,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) self.actor = Actor( action_dim=int(np.prod(self.action_space.shape)), net_arch=self.net_arch_pi, + activation_fn=self.activation_fn, ) # Hack to make gSDE work without modifying internal SB3 code self.actor.reset_noise = self.reset_noise @@ -155,6 +159,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) use_layer_norm=self.layer_norm, net_arch=self.net_arch_qf, n_quantiles=self.n_quantiles, + activation_fn=self.activation_fn, ) self.qf1_state = RLTrainState.create( diff --git a/sbx/version.txt b/sbx/version.txt index ac454c6..54d1a4f 100644 --- a/sbx/version.txt +++ b/sbx/version.txt @@ -1 +1 @@ -0.12.0 +0.13.0 diff --git a/tests/test_run.py b/tests/test_run.py index 010d017..ddbb2b1 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,5 +1,6 @@ from typing import Optional, Type +import flax.linen as nn import numpy as np import pytest from stable_baselines3 import HerReplayBuffer @@ -72,16 +73,33 @@ def test_sac_td3(model_class) -> None: model.learn(110) +@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, DQN]) +def test_policy_kwargs(model_class) -> None: + env_id = "CartPole-v1" if model_class == DQN else "Pendulum-v1" + + model = model_class( + "MlpPolicy", + env_id, + verbose=1, + gradient_steps=1, + learning_rate=1e-3, + policy_kwargs=dict(activation_fn=nn.leaky_relu, net_arch=[8]), + ) + model.learn(110) + + @pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"]) def test_ppo(env_id: str) -> None: model = PPO( "MlpPolicy", env_id, verbose=1, - n_steps=64, + n_steps=32, + batch_size=32, n_epochs=2, + policy_kwargs=dict(activation_fn=nn.leaky_relu), ) - model.learn(128, progress_bar=True) + model.learn(64, progress_bar=True) def test_dqn() -> None: