Allow passing a custom activation function in policy_kwargs (#41)
* Added an act_fn option to TD3 and SAC

* Updated the Makefile for the new ruff version

* Added a test for the policy_kwargs argument

* Allowed passing activation_fn to the SAC model

* Added the possibility to pass act_fn to DQN as well

* Reformatted dqn/policies.py according to black style

* Added a type hint to PPO, matching the other policies for consistency

* Applied the suggested adjustments

* Updated tests and version

* Added support for TQC

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
paolodelia99 and araffin authored Apr 3, 2024
1 parent 46dcd7f commit 655f4a3
Showing 9 changed files with 110 additions and 117 deletions.
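
For context, here is a minimal usage sketch of the option this commit adds. The environment id, timestep count, and choice of nn.tanh are illustrative only; it assumes the sbx package at this version (with gymnasium) is installed:

import flax.linen as nn

from sbx import SAC

# Pass a custom activation function through policy_kwargs
# (the default for the off-policy algorithms remains nn.relu).
model = SAC(
    "MlpPolicy",
    "Pendulum-v1",
    policy_kwargs=dict(activation_fn=nn.tanh),
    verbose=1,
)
model.learn(total_timesteps=1_000)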
8 changes: 4 additions & 4 deletions Makefile
@@ -12,19 +12,19 @@ type: mypy
lint:
# stop the build if there are Python syntax errors or undefined names
# see https://www.flake8rules.com/
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
# exit-zero treats all errors as warnings.
ruff ${LINT_PATHS} --exit-zero
ruff check ${LINT_PATHS} --exit-zero

format:
# Sort imports
ruff --select I ${LINT_PATHS} --fix
ruff check --select I ${LINT_PATHS} --fix
# Reformat using black
black ${LINT_PATHS}

check-codestyle:
# Sort imports
ruff --select I ${LINT_PATHS}
ruff check --select I ${LINT_PATHS}
# Reformat using black
black --check ${LINT_PATHS}

51 changes: 50 additions & 1 deletion sbx/common/policies.py
@@ -1,5 +1,5 @@
# import copy
from typing import Dict, Optional, Tuple, Union, no_type_check
from typing import Callable, Dict, Optional, Sequence, Tuple, Union, no_type_check

import flax.linen as nn
import jax
@@ -120,3 +120,52 @@ def set_training_mode(self, mode: bool) -> None:
# self.actor.set_training_mode(mode)
# self.critic.set_training_mode(mode)
self.training = mode


class ContinuousCritic(nn.Module):
net_arch: Sequence[int]
use_layer_norm: bool = False
dropout_rate: Optional[float] = None
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

@nn.compact
def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
x = Flatten()(x)
x = jnp.concatenate([x, action], -1)
for n_units in self.net_arch:
x = nn.Dense(n_units)(x)
if self.dropout_rate is not None and self.dropout_rate > 0:
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False)
if self.use_layer_norm:
x = nn.LayerNorm()(x)
x = self.activation_fn(x)
x = nn.Dense(1)(x)
return x


class VectorCritic(nn.Module):
net_arch: Sequence[int]
use_layer_norm: bool = False
dropout_rate: Optional[float] = None
n_critics: int = 2
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

@nn.compact
def __call__(self, obs: jnp.ndarray, action: jnp.ndarray):
# Idea taken from https://github.com/perrin-isir/xpag
# Similar to https://github.com/tinkoff-ai/CORL for PyTorch
vmap_critic = nn.vmap(
ContinuousCritic,
variable_axes={"params": 0}, # parameters not shared between the critics
split_rngs={"params": True, "dropout": True}, # different initializations
in_axes=None,
out_axes=0,
axis_size=self.n_critics,
)
q_values = vmap_critic(
use_layer_norm=self.use_layer_norm,
dropout_rate=self.dropout_rate,
net_arch=self.net_arch,
activation_fn=self.activation_fn,
)(obs, action)
return q_values
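
As a rough sketch of how the shared VectorCritic behaves (the dummy observation/action dimensions and net_arch below are chosen for illustration, not taken from the test suite):

import jax
import jax.numpy as jnp

from sbx.common.policies import VectorCritic

# Two critics over a dummy 3-dim observation and 1-dim action.
critic = VectorCritic(net_arch=[64, 64], n_critics=2)
obs = jnp.zeros((1, 3))
action = jnp.zeros((1, 1))
params = critic.init(jax.random.PRNGKey(0), obs, action)
q_values = critic.apply(params, obs, action)
# nn.vmap stacks the per-critic outputs along axis 0: (n_critics, batch_size, 1)
print(q_values.shape)  # (2, 1, 1)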
13 changes: 10 additions & 3 deletions sbx/dqn/policies.py
@@ -15,14 +15,15 @@
class QNetwork(nn.Module):
n_actions: int
n_units: int = 256
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
x = Flatten()(x)
x = nn.Dense(self.n_units)(x)
x = nn.relu(x)
x = self.activation_fn(x)
x = nn.Dense(self.n_units)(x)
x = nn.relu(x)
x = self.activation_fn(x)
x = nn.Dense(self.n_actions)(x)
return x

@@ -36,6 +37,7 @@ def __init__(
action_space: spaces.Discrete,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu,
features_extractor_class=None,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
@@ -56,13 +58,18 @@ def __init__(
self.n_units = net_arch[0]
else:
self.n_units = 256
self.activation_fn = activation_fn

def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array:
key, qf_key = jax.random.split(key, 2)

obs = jnp.array([self.observation_space.sample()])

self.qf = QNetwork(n_actions=int(self.action_space.n), n_units=self.n_units)
self.qf = QNetwork(
n_actions=int(self.action_space.n),
n_units=self.n_units,
activation_fn=self.activation_fn,
)

self.qf_state = RLTrainState.create(
apply_fn=self.qf.apply,
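The same policy_kwargs route now works for DQN as well; a short hedged sketch (the environment id is illustrative):

import flax.linen as nn

from sbx import DQN

# activation_fn is forwarded to the QNetwork defined above.
model = DQN(
    "MlpPolicy",
    "CartPole-v1",
    policy_kwargs=dict(activation_fn=nn.tanh),
)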
6 changes: 3 additions & 3 deletions sbx/ppo/policies.py
@@ -20,7 +20,7 @@

class Critic(nn.Module):
n_units: int = 256
activation_fn: Callable = nn.tanh
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
@@ -37,7 +37,7 @@ class Actor(nn.Module):
action_dim: int
n_units: int = 256
log_std_init: float = 0.0
activation_fn: Callable = nn.tanh
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh
# For Discrete, MultiDiscrete and MultiBinary actions
num_discrete_choices: Optional[Union[int, Sequence[int]]] = None
# For MultiDiscrete
@@ -104,7 +104,7 @@ def __init__(
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
ortho_init: bool = False,
log_std_init: float = 0.0,
activation_fn=nn.tanh,
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh,
use_sde: bool = False,
# Note: most gSDE parameters are not used
# this is to keep API consistent with SB3
56 changes: 7 additions & 49 deletions sbx/sac/policies.py
@@ -11,63 +11,18 @@
from stable_baselines3.common.type_aliases import Schedule

from sbx.common.distributions import TanhTransformedDistribution
from sbx.common.policies import BaseJaxPolicy, Flatten
from sbx.common.policies import BaseJaxPolicy, Flatten, VectorCritic
from sbx.common.type_aliases import RLTrainState

tfd = tfp.distributions


class Critic(nn.Module):
net_arch: Sequence[int]
use_layer_norm: bool = False
dropout_rate: Optional[float] = None

@nn.compact
def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
x = Flatten()(x)
x = jnp.concatenate([x, action], -1)
for n_units in self.net_arch:
x = nn.Dense(n_units)(x)
if self.dropout_rate is not None and self.dropout_rate > 0:
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False)
if self.use_layer_norm:
x = nn.LayerNorm()(x)
x = nn.relu(x)
x = nn.Dense(1)(x)
return x


class VectorCritic(nn.Module):
net_arch: Sequence[int]
use_layer_norm: bool = False
dropout_rate: Optional[float] = None
n_critics: int = 2

@nn.compact
def __call__(self, obs: jnp.ndarray, action: jnp.ndarray):
# Idea taken from https://github.com/perrin-isir/xpag
# Similar to https://github.com/tinkoff-ai/CORL for PyTorch
vmap_critic = nn.vmap(
Critic,
variable_axes={"params": 0}, # parameters not shared between the critics
split_rngs={"params": True, "dropout": True}, # different initializations
in_axes=None,
out_axes=0,
axis_size=self.n_critics,
)
q_values = vmap_critic(
use_layer_norm=self.use_layer_norm,
dropout_rate=self.dropout_rate,
net_arch=self.net_arch,
)(obs, action)
return q_values


class Actor(nn.Module):
net_arch: Sequence[int]
action_dim: int
log_std_min: float = -20
log_std_max: float = 2
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

def get_std(self):
# Make it work with gSDE
@@ -78,7 +33,7 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def
x = Flatten()(x)
for n_units in self.net_arch:
x = nn.Dense(n_units)(x)
x = nn.relu(x)
x = self.activation_fn(x)
mean = nn.Dense(self.action_dim)(x)
log_std = nn.Dense(self.action_dim)(x)
log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max)
@@ -99,7 +54,7 @@ def __init__(
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
dropout_rate: float = 0.0,
layer_norm: bool = False,
# activation_fn: Type[nn.Module] = nn.ReLU,
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu,
use_sde: bool = False,
# Note: most gSDE parameters are not used
# this is to keep API consistent with SB3
@@ -135,6 +90,7 @@ def __init__(
self.net_arch_pi = self.net_arch_qf = [256, 256]
self.n_critics = n_critics
self.use_sde = use_sde
self.activation_fn = activation_fn

self.key = self.noise_key = jax.random.PRNGKey(0)

@@ -154,6 +110,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float)
self.actor = Actor(
action_dim=int(np.prod(self.action_space.shape)),
net_arch=self.net_arch_pi,
activation_fn=self.activation_fn,
)
# Hack to make gSDE work without modifying internal SB3 code
self.actor.reset_noise = self.reset_noise
@@ -172,6 +129,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float)
use_layer_norm=self.layer_norm,
net_arch=self.net_arch_qf,
n_critics=self.n_critics,
activation_fn=self.activation_fn,
)

self.qf_state = RLTrainState.create(
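A sketch combining the new activation_fn option with the existing dropout_rate and layer_norm policy kwargs of the SAC policy shown above (all values are illustrative, not taken from this commit):

import flax.linen as nn

from sbx import SAC

model = SAC(
    "MlpPolicy",
    "Pendulum-v1",
    policy_kwargs=dict(
        activation_fn=nn.gelu,  # flax.linen activation, illustrative choice
        dropout_rate=0.01,
        layer_norm=True,
        net_arch=[256, 256],
    ),
)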
58 changes: 7 additions & 51 deletions sbx/td3/policies.py
@@ -8,66 +8,21 @@
from gymnasium import spaces
from stable_baselines3.common.type_aliases import Schedule

from sbx.common.policies import BaseJaxPolicy, Flatten
from sbx.common.policies import BaseJaxPolicy, Flatten, VectorCritic
from sbx.common.type_aliases import RLTrainState


class Critic(nn.Module):
net_arch: Sequence[int]
use_layer_norm: bool = False
dropout_rate: Optional[float] = None

@nn.compact
def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
x = Flatten()(x)
x = jnp.concatenate([x, action], -1)
for n_units in self.net_arch:
x = nn.Dense(n_units)(x)
if self.dropout_rate is not None and self.dropout_rate > 0:
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False)
if self.use_layer_norm:
x = nn.LayerNorm()(x)
x = nn.relu(x)
x = nn.Dense(1)(x)
return x


class VectorCritic(nn.Module):
net_arch: Sequence[int]
use_layer_norm: bool = False
dropout_rate: Optional[float] = None
n_critics: int = 2

@nn.compact
def __call__(self, obs: jnp.ndarray, action: jnp.ndarray):
# Idea taken from https://github.com/perrin-isir/xpag
# Similar to https://github.com/tinkoff-ai/CORL for PyTorch
vmap_critic = nn.vmap(
Critic,
variable_axes={"params": 0}, # parameters not shared between the critics
split_rngs={"params": True, "dropout": True}, # different initializations
in_axes=None,
out_axes=0,
axis_size=self.n_critics,
)
q_values = vmap_critic(
use_layer_norm=self.use_layer_norm,
dropout_rate=self.dropout_rate,
net_arch=self.net_arch,
)(obs, action)
return q_values


class Actor(nn.Module):
net_arch: Sequence[int]
action_dim: int
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray: # type: ignore[name-defined]
x = Flatten()(x)
for n_units in self.net_arch:
x = nn.Dense(n_units)(x)
x = nn.relu(x)
x = self.activation_fn(x)
return nn.tanh(nn.Dense(self.action_dim)(x))


@@ -82,7 +37,7 @@ def __init__(
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
dropout_rate: float = 0.0,
layer_norm: bool = False,
# activation_fn: Type[nn.Module] = nn.ReLU,
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu,
use_sde: bool = False,
features_extractor_class=None,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
@@ -112,6 +67,7 @@ def __init__(
else:
self.net_arch_pi = self.net_arch_qf = [256, 256]
self.n_critics = n_critics
self.activation_fn = activation_fn

self.key = self.noise_key = jax.random.PRNGKey(0)

@@ -127,8 +83,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float)
action = jnp.array([self.action_space.sample()])

self.actor = Actor(
action_dim=int(np.prod(self.action_space.shape)),
net_arch=self.net_arch_pi,
action_dim=int(np.prod(self.action_space.shape)), net_arch=self.net_arch_pi, activation_fn=self.activation_fn
)

self.actor_state = RLTrainState.create(
@@ -146,6 +101,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float)
use_layer_norm=self.layer_norm,
net_arch=self.net_arch_qf,
n_critics=self.n_critics,
activation_fn=self.activation_fn,
)

self.qf_state = RLTrainState.create(
(Diffs for the remaining 3 changed files were not loaded in this view.)
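
Per the commit message, TQC gains the same option; its diff is among the files not shown here, so treat the following as an assumption about the analogous policy_kwargs rather than confirmed API:

import flax.linen as nn

from sbx import TQC

model = TQC(
    "MlpPolicy",
    "Pendulum-v1",
    policy_kwargs=dict(activation_fn=nn.elu),  # illustrative activation choice
)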
