Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing a custom activation function in policy_kwargs #41

Merged
merged 10 commits into from
Apr 3, 2024
8 changes: 4 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,19 @@ type: mypy
lint:
	# Stop the build if there are Python syntax errors or undefined names.
	# See https://www.flake8rules.com/
	# Ruff is invoked through its explicit `check` subcommand.
	ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
	# exit-zero treats all errors as warnings.
	ruff check ${LINT_PATHS} --exit-zero

format:
	# Sort imports (Ruff's isort rules, applied in place via --fix)
	ruff check --select I ${LINT_PATHS} --fix
	# Reformat using black
	black ${LINT_PATHS}

check-codestyle:
	# Check import ordering (Ruff's isort rules, no files are modified)
	ruff check --select I ${LINT_PATHS}
	# Verify formatting with black without rewriting files
	black --check ${LINT_PATHS}

Expand Down
51 changes: 50 additions & 1 deletion sbx/common/policies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# import copy
from typing import Dict, Optional, Tuple, Union, no_type_check
from typing import Callable, Dict, Optional, Sequence, Tuple, Union, no_type_check

import flax.linen as nn
import jax
Expand Down Expand Up @@ -120,3 +120,52 @@ def set_training_mode(self, mode: bool) -> None:
# self.actor.set_training_mode(mode)
# self.critic.set_training_mode(mode)
self.training = mode


class ContinuousCritic(nn.Module):
    """MLP critic that scores an (observation, action) pair.

    The observation is flattened, concatenated with the action along the
    last axis, passed through one hidden block per entry of ``net_arch``
    and finally projected to a single output unit.
    """

    # Width of each hidden Dense layer, in order.
    net_arch: Sequence[int]
    # When True, a LayerNorm is inserted before the activation of each hidden layer.
    use_layer_norm: bool = False
    # Dropout rate applied after each hidden Dense layer; None or 0 disables dropout.
    dropout_rate: Optional[float] = None
    # Non-linearity applied at the end of every hidden block.
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    @nn.compact
    def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
        """Return the critic output for observation ``x`` and ``action``."""
        hidden = jnp.concatenate([Flatten()(x), action], -1)
        apply_dropout = self.dropout_rate is not None and self.dropout_rate > 0
        for width in self.net_arch:
            # Hidden block order: Dense -> (Dropout) -> (LayerNorm) -> activation.
            hidden = nn.Dense(width)(hidden)
            if apply_dropout:
                # Dropout is always active here (deterministic=False).
                hidden = nn.Dropout(rate=self.dropout_rate)(hidden, deterministic=False)
            if self.use_layer_norm:
                hidden = nn.LayerNorm()(hidden)
            hidden = self.activation_fn(hidden)
        return nn.Dense(1)(hidden)


class VectorCritic(nn.Module):
    """Ensemble of ``n_critics`` independent ``ContinuousCritic`` networks.

    The critics are stacked with ``nn.vmap`` so that a single call evaluates
    all of them on the same (observation, action) input.
    """

    # Hidden layer widths forwarded to every critic.
    net_arch: Sequence[int]
    # Forwarded to every critic (LayerNorm before each activation).
    use_layer_norm: bool = False
    # Forwarded to every critic (dropout after each hidden Dense layer).
    dropout_rate: Optional[float] = None
    # Number of critics in the ensemble (size of the vmapped axis).
    n_critics: int = 2
    # Non-linearity forwarded to every critic.
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    @nn.compact
    def __call__(self, obs: jnp.ndarray, action: jnp.ndarray):
        """Evaluate all critics; outputs are stacked along axis 0."""
        # Idea taken from https://github.com/perrin-isir/xpag
        # Similar to https://github.com/tinkoff-ai/CORL for PyTorch
        ensemble_cls = nn.vmap(
            ContinuousCritic,
            # parameters not shared between the critics
            variable_axes={"params": 0},
            # different initializations
            split_rngs={"params": True, "dropout": True},
            # in_axes=None: every critic receives the same inputs
            in_axes=None,
            out_axes=0,
            axis_size=self.n_critics,
        )
        ensemble = ensemble_cls(
            net_arch=self.net_arch,
            use_layer_norm=self.use_layer_norm,
            dropout_rate=self.dropout_rate,
            activation_fn=self.activation_fn,
        )
        return ensemble(obs, action)
13 changes: 10 additions & 3 deletions sbx/dqn/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
class QNetwork(nn.Module):
    """Two-hidden-layer MLP that outputs one value per discrete action.

    The hidden non-linearity is configurable through ``activation_fn`` and
    is the only activation applied after each hidden Dense layer.
    """

    # Size of the output layer (one unit per action).
    n_actions: int
    # Width of both hidden Dense layers.
    n_units: int = 256
    # Non-linearity applied after each hidden Dense layer.
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Flatten ``x`` and return the network output with ``n_actions`` units."""
        x = Flatten()(x)
        x = nn.Dense(self.n_units)(x)
        x = self.activation_fn(x)
        x = nn.Dense(self.n_units)(x)
        x = self.activation_fn(x)
        x = nn.Dense(self.n_actions)(x)
        return x

Expand All @@ -36,6 +37,7 @@ def __init__(
action_space: spaces.Discrete,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu,
features_extractor_class=None,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
Expand All @@ -56,13 +58,18 @@ def __init__(
self.n_units = net_arch[0]
else:
self.n_units = 256
self.activation_fn = activation_fn

def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array:
key, qf_key = jax.random.split(key, 2)

obs = jnp.array([self.observation_space.sample()])

self.qf = QNetwork(n_actions=int(self.action_space.n), n_units=self.n_units)
self.qf = QNetwork(
n_actions=int(self.action_space.n),
n_units=self.n_units,
activation_fn=self.activation_fn,
)

self.qf_state = RLTrainState.create(
apply_fn=self.qf.apply,
Expand Down
6 changes: 3 additions & 3 deletions sbx/ppo/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

class Critic(nn.Module):
n_units: int = 256
activation_fn: Callable = nn.tanh
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
Expand All @@ -37,7 +37,7 @@ class Actor(nn.Module):
action_dim: int
n_units: int = 256
log_std_init: float = 0.0
activation_fn: Callable = nn.tanh
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh
# For Discrete, MultiDiscrete and MultiBinary actions
num_discrete_choices: Optional[Union[int, Sequence[int]]] = None
# For MultiDiscrete
Expand Down Expand Up @@ -104,7 +104,7 @@ def __init__(
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
ortho_init: bool = False,
log_std_init: float = 0.0,
activation_fn=nn.tanh,
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh,
use_sde: bool = False,
# Note: most gSDE parameters are not used
# this is to keep API consistent with SB3
Expand Down
56 changes: 7 additions & 49 deletions sbx/sac/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,63 +11,18 @@
from stable_baselines3.common.type_aliases import Schedule

from sbx.common.distributions import TanhTransformedDistribution
from sbx.common.policies import BaseJaxPolicy, Flatten
from sbx.common.policies import BaseJaxPolicy, Flatten, VectorCritic
from sbx.common.type_aliases import RLTrainState

tfd = tfp.distributions


class Critic(nn.Module):
    """MLP critic over an (observation, action) pair with ReLU hidden layers.

    The observation is flattened, concatenated with the action along the
    last axis, passed through one hidden block per entry of ``net_arch``
    and projected to a single output unit.
    """

    # Width of each hidden Dense layer, in order.
    net_arch: Sequence[int]
    # When True, a LayerNorm is inserted before the ReLU of each hidden layer.
    use_layer_norm: bool = False
    # Dropout rate applied after each hidden Dense layer; None or 0 disables dropout.
    dropout_rate: Optional[float] = None

    @nn.compact
    def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
        """Return the critic output for observation ``x`` and ``action``."""
        hidden = jnp.concatenate([Flatten()(x), action], -1)
        apply_dropout = self.dropout_rate is not None and self.dropout_rate > 0
        for width in self.net_arch:
            # Hidden block order: Dense -> (Dropout) -> (LayerNorm) -> ReLU.
            hidden = nn.Dense(width)(hidden)
            if apply_dropout:
                # Dropout is always active here (deterministic=False).
                hidden = nn.Dropout(rate=self.dropout_rate)(hidden, deterministic=False)
            if self.use_layer_norm:
                hidden = nn.LayerNorm()(hidden)
            hidden = nn.relu(hidden)
        return nn.Dense(1)(hidden)


class VectorCritic(nn.Module):
    """Ensemble of ``n_critics`` independent ``Critic`` networks.

    The critics are stacked with ``nn.vmap`` so that a single call evaluates
    all of them on the same (observation, action) input.
    """

    # Hidden layer widths forwarded to every critic.
    net_arch: Sequence[int]
    # Forwarded to every critic (LayerNorm before each activation).
    use_layer_norm: bool = False
    # Forwarded to every critic (dropout after each hidden Dense layer).
    dropout_rate: Optional[float] = None
    # Number of critics in the ensemble (size of the vmapped axis).
    n_critics: int = 2

    @nn.compact
    def __call__(self, obs: jnp.ndarray, action: jnp.ndarray):
        """Evaluate all critics; outputs are stacked along axis 0."""
        # Idea taken from https://github.com/perrin-isir/xpag
        # Similar to https://github.com/tinkoff-ai/CORL for PyTorch
        ensemble_cls = nn.vmap(
            Critic,
            # parameters not shared between the critics
            variable_axes={"params": 0},
            # different initializations
            split_rngs={"params": True, "dropout": True},
            # in_axes=None: every critic receives the same inputs
            in_axes=None,
            out_axes=0,
            axis_size=self.n_critics,
        )
        ensemble = ensemble_cls(
            net_arch=self.net_arch,
            use_layer_norm=self.use_layer_norm,
            dropout_rate=self.dropout_rate,
        )
        return ensemble(obs, action)


class Actor(nn.Module):
net_arch: Sequence[int]
action_dim: int
log_std_min: float = -20
log_std_max: float = 2
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

def get_std(self):
# Make it work with gSDE
Expand All @@ -78,7 +33,7 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def
x = Flatten()(x)
for n_units in self.net_arch:
x = nn.Dense(n_units)(x)
x = nn.relu(x)
x = self.activation_fn(x)
mean = nn.Dense(self.action_dim)(x)
log_std = nn.Dense(self.action_dim)(x)
log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max)
Expand All @@ -99,7 +54,7 @@ def __init__(
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
dropout_rate: float = 0.0,
layer_norm: bool = False,
# activation_fn: Type[nn.Module] = nn.ReLU,
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu,
use_sde: bool = False,
# Note: most gSDE parameters are not used
# this is to keep API consistent with SB3
Expand Down Expand Up @@ -135,6 +90,7 @@ def __init__(
self.net_arch_pi = self.net_arch_qf = [256, 256]
self.n_critics = n_critics
self.use_sde = use_sde
self.activation_fn = activation_fn

self.key = self.noise_key = jax.random.PRNGKey(0)

Expand All @@ -154,6 +110,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float)
self.actor = Actor(
action_dim=int(np.prod(self.action_space.shape)),
net_arch=self.net_arch_pi,
activation_fn=self.activation_fn,
)
# Hack to make gSDE work without modifying internal SB3 code
self.actor.reset_noise = self.reset_noise
Expand All @@ -172,6 +129,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float)
use_layer_norm=self.layer_norm,
net_arch=self.net_arch_qf,
n_critics=self.n_critics,
activation_fn=self.activation_fn,
)

self.qf_state = RLTrainState.create(
Expand Down
58 changes: 7 additions & 51 deletions sbx/td3/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,66 +8,21 @@
from gymnasium import spaces
from stable_baselines3.common.type_aliases import Schedule

from sbx.common.policies import BaseJaxPolicy, Flatten
from sbx.common.policies import BaseJaxPolicy, Flatten, VectorCritic
from sbx.common.type_aliases import RLTrainState


class Critic(nn.Module):
    """MLP critic over an (observation, action) pair with ReLU hidden layers.

    The observation is flattened, concatenated with the action along the
    last axis, passed through one hidden block per entry of ``net_arch``
    and projected to a single output unit.
    """

    # Width of each hidden Dense layer, in order.
    net_arch: Sequence[int]
    # When True, a LayerNorm is inserted before the ReLU of each hidden layer.
    use_layer_norm: bool = False
    # Dropout rate applied after each hidden Dense layer; None or 0 disables dropout.
    dropout_rate: Optional[float] = None

    @nn.compact
    def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
        """Return the critic output for observation ``x`` and ``action``."""
        hidden = jnp.concatenate([Flatten()(x), action], -1)
        apply_dropout = self.dropout_rate is not None and self.dropout_rate > 0
        for width in self.net_arch:
            # Hidden block order: Dense -> (Dropout) -> (LayerNorm) -> ReLU.
            hidden = nn.Dense(width)(hidden)
            if apply_dropout:
                # Dropout is always active here (deterministic=False).
                hidden = nn.Dropout(rate=self.dropout_rate)(hidden, deterministic=False)
            if self.use_layer_norm:
                hidden = nn.LayerNorm()(hidden)
            hidden = nn.relu(hidden)
        return nn.Dense(1)(hidden)


class VectorCritic(nn.Module):
    """Ensemble of ``n_critics`` independent ``Critic`` networks.

    The critics are stacked with ``nn.vmap`` so that a single call evaluates
    all of them on the same (observation, action) input.
    """

    # Hidden layer widths forwarded to every critic.
    net_arch: Sequence[int]
    # Forwarded to every critic (LayerNorm before each activation).
    use_layer_norm: bool = False
    # Forwarded to every critic (dropout after each hidden Dense layer).
    dropout_rate: Optional[float] = None
    # Number of critics in the ensemble (size of the vmapped axis).
    n_critics: int = 2

    @nn.compact
    def __call__(self, obs: jnp.ndarray, action: jnp.ndarray):
        """Evaluate all critics; outputs are stacked along axis 0."""
        # Idea taken from https://github.com/perrin-isir/xpag
        # Similar to https://github.com/tinkoff-ai/CORL for PyTorch
        ensemble_cls = nn.vmap(
            Critic,
            # parameters not shared between the critics
            variable_axes={"params": 0},
            # different initializations
            split_rngs={"params": True, "dropout": True},
            # in_axes=None: every critic receives the same inputs
            in_axes=None,
            out_axes=0,
            axis_size=self.n_critics,
        )
        ensemble = ensemble_cls(
            net_arch=self.net_arch,
            use_layer_norm=self.use_layer_norm,
            dropout_rate=self.dropout_rate,
        )
        return ensemble(obs, action)


class Actor(nn.Module):
    """Deterministic MLP actor whose output is squashed with ``tanh``.

    The hidden non-linearity is configurable through ``activation_fn`` and
    is the only activation applied after each hidden Dense layer.
    """

    # Width of each hidden Dense layer, in order.
    net_arch: Sequence[int]
    # Number of units in the output layer.
    action_dim: int
    # Non-linearity applied after each hidden Dense layer.
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:  # type: ignore[name-defined]
        """Flatten ``x`` and return the tanh-squashed output with ``action_dim`` units."""
        x = Flatten()(x)
        for n_units in self.net_arch:
            x = nn.Dense(n_units)(x)
            x = self.activation_fn(x)
        # tanh bounds every output component to (-1, 1).
        return nn.tanh(nn.Dense(self.action_dim)(x))


Expand All @@ -82,7 +37,7 @@ def __init__(
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
dropout_rate: float = 0.0,
layer_norm: bool = False,
# activation_fn: Type[nn.Module] = nn.ReLU,
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu,
use_sde: bool = False,
features_extractor_class=None,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -112,6 +67,7 @@ def __init__(
else:
self.net_arch_pi = self.net_arch_qf = [256, 256]
self.n_critics = n_critics
self.activation_fn = activation_fn

self.key = self.noise_key = jax.random.PRNGKey(0)

Expand All @@ -127,8 +83,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float)
action = jnp.array([self.action_space.sample()])

self.actor = Actor(
action_dim=int(np.prod(self.action_space.shape)),
net_arch=self.net_arch_pi,
action_dim=int(np.prod(self.action_space.shape)), net_arch=self.net_arch_pi, activation_fn=self.activation_fn
)

self.actor_state = RLTrainState.create(
Expand All @@ -146,6 +101,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float)
use_layer_norm=self.layer_norm,
net_arch=self.net_arch_qf,
n_critics=self.n_critics,
activation_fn=self.activation_fn,
)

self.qf_state = RLTrainState.create(
Expand Down
Loading
Loading