Commit
Allow changing the number of warmup steps
araffin committed Mar 31, 2024
1 parent 55d60e3 commit 2ef9dbe
Showing 2 changed files with 12 additions and 6 deletions.
sbx/common/jax_layers.py (4 changes: 2 additions & 2 deletions)
@@ -80,7 +80,7 @@ class BatchRenorm(Module):
     axis: int = -1
     momentum: float = 0.99
     epsilon: float = 0.001
-    warm_up_steps: int = 100_000
+    warmup_steps: int = 100_000
     dtype: Optional[Dtype] = None
     param_dtype: Dtype = jnp.float32
     use_bias: bool = True
@@ -181,7 +181,7 @@ def __call__(self, x, use_running_average: Optional[bool] = None):
                 # Note: in the original paper, after some warmup phase (batch norm phase of 5k steps)
                 # the constraints are linearly relaxed to r_max/d_max over 40k steps
                 # Here we only have a warmup phase
-                is_warmed_up = jnp.greater_equal(steps.value, self.warm_up_steps).astype(jnp.float32)
+                is_warmed_up = jnp.greater_equal(steps.value, self.warmup_steps).astype(jnp.float32)
                 custom_mean = is_warmed_up * affine_mean + (1.0 - is_warmed_up) * batch_mean
                 custom_var = is_warmed_up * affine_var + (1.0 - is_warmed_up) * batch_var

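The change above is a hard switch: for the first warmup_steps gradient steps the layer normalizes with current batch statistics, and afterwards with the renormalized (affine) statistics. A minimal standalone sketch of that gating, with an illustrative function name and values that are not part of the commit:

import jax.numpy as jnp

def gated_mean(steps, warmup_steps, batch_mean, affine_mean):
    # 0.0 while still warming up, 1.0 once steps >= warmup_steps
    is_warmed_up = jnp.greater_equal(steps, warmup_steps).astype(jnp.float32)
    # Hard switch: batch statistics during warmup, affine statistics after
    return is_warmed_up * affine_mean + (1.0 - is_warmed_up) * batch_mean

print(gated_mean(50_000, 100_000, 0.1, 0.5))   # 0.1 -> batch mean still in use
print(gated_mean(150_000, 100_000, 0.1, 0.5))  # 0.5 -> switched to affine mean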
sbx/crossq/policies.py (14 changes: 10 additions & 4 deletions)
@@ -23,7 +23,7 @@ class Critic(nn.Module):
     use_batch_norm: bool = True
     dropout_rate: Optional[float] = None
     batch_norm_momentum: float = 0.99
-    renorm_warm_up_steps: int = 100_000
+    renorm_warmup_steps: int = 100_000

     @nn.compact
     def __call__(self, x: jnp.ndarray, action: jnp.ndarray, train: bool = False) -> jnp.ndarray:
@@ -33,7 +33,7 @@ def __call__(self, x: jnp.ndarray, action: jnp.ndarray, train: bool = False) -> jnp.ndarray:
             x = BatchRenorm(
                 use_running_average=not train,
                 momentum=self.batch_norm_momentum,
-                warm_up_steps=self.renorm_warm_up_steps,
+                warmup_steps=self.renorm_warmup_steps,
             )(x)
         else:
             # Create dummy batchstats
@@ -58,6 +58,7 @@ class VectorCritic(nn.Module):
     use_layer_norm: bool = False
     use_batch_norm: bool = True
     batch_norm_momentum: float = 0.99
+    renorm_warmup_steps: int = 100_000
     dropout_rate: Optional[float] = None
     n_critics: int = 2

@@ -77,6 +78,7 @@ def __call__(self, obs: jnp.ndarray, action: jnp.ndarray, train: bool = False):
             use_layer_norm=self.use_layer_norm,
             use_batch_norm=self.use_batch_norm,
             batch_norm_momentum=self.batch_norm_momentum,
+            renorm_warmup_steps=self.renorm_warmup_steps,
             dropout_rate=self.dropout_rate,
             net_arch=self.net_arch,
         )(obs, action, train)
@@ -90,7 +92,7 @@ class Actor(nn.Module):
     log_std_max: float = 2
     use_batch_norm: bool = True
     batch_norm_momentum: float = 0.99
-    renorm_warm_up_steps: int = 100_000
+    renorm_warmup_steps: int = 100_000

     def get_std(self):
         # Make it work with gSDE
@@ -103,7 +105,7 @@ def __call__(self, x: jnp.ndarray, train: bool = False) -> tfd.Distribution:  #
             x = BatchRenorm(
                 use_running_average=not train,
                 momentum=self.batch_norm_momentum,
-                warm_up_steps=self.renorm_warm_up_steps,
+                warmup_steps=self.renorm_warmup_steps,
             )(x)
         else:
             # Create dummy batchstats
@@ -138,6 +140,7 @@ def __init__(
         batch_norm: bool = True,  # for critic
         batch_norm_actor: bool = True,
         batch_norm_momentum: float = 0.99,
+        renorm_warmup_steps: int = 100_000,
         use_sde: bool = False,
         # Note: most gSDE parameters are not used
         # this is to keep API consistent with SB3
@@ -174,6 +177,7 @@ def __init__(
         self.batch_norm = batch_norm
         self.batch_norm_momentum = batch_norm_momentum
         self.batch_norm_actor = batch_norm_actor
+        self.renorm_warmup_steps = renorm_warmup_steps

         if net_arch is not None:
             if isinstance(net_arch, list):
@@ -211,6 +215,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float)
             net_arch=self.net_arch_pi,
             use_batch_norm=self.batch_norm_actor,
             batch_norm_momentum=self.batch_norm_momentum,
+            renorm_warmup_steps=self.renorm_warmup_steps,
         )
         # Hack to make gSDE work without modifying internal SB3 code
         self.actor.reset_noise = self.reset_noise
@@ -237,6 +242,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float)
             use_layer_norm=self.layer_norm,
             use_batch_norm=self.batch_norm,
             batch_norm_momentum=self.batch_norm_momentum,
+            renorm_warmup_steps=self.renorm_warmup_steps,
             net_arch=self.net_arch_qf,
             n_critics=self.n_critics,
         )
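With the parameter threaded through Critic, VectorCritic, Actor, and the policy constructor, the warmup length can now be overridden from user code instead of being fixed at 100_000 steps. A hypothetical usage sketch, assuming the SB3-style policy_kwargs forwarding that SBX follows (the environment ID and step counts are illustrative, not part of this commit):

from sbx import CrossQ

# Shorten BatchRenorm's warmup from the default of 100_000 gradient steps
model = CrossQ(
    "MlpPolicy",
    "Pendulum-v1",
    policy_kwargs=dict(renorm_warmup_steps=50_000),
)
model.learn(total_timesteps=10_000)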