From 2ef9dbe5d38953739916ecf83707695109ec6657 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 31 Mar 2024 12:36:42 +0200 Subject: [PATCH] Allow to change the number of warmup steps --- sbx/common/jax_layers.py | 4 ++-- sbx/crossq/policies.py | 14 ++++++++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/sbx/common/jax_layers.py b/sbx/common/jax_layers.py index c3ebf4e..67e79e8 100644 --- a/sbx/common/jax_layers.py +++ b/sbx/common/jax_layers.py @@ -80,7 +80,7 @@ class BatchRenorm(Module): axis: int = -1 momentum: float = 0.99 epsilon: float = 0.001 - warm_up_steps: int = 100_000 + warmup_steps: int = 100_000 dtype: Optional[Dtype] = None param_dtype: Dtype = jnp.float32 use_bias: bool = True @@ -181,7 +181,7 @@ def __call__(self, x, use_running_average: Optional[bool] = None): # Note: in the original paper, after some warmup phase (batch norm phase of 5k steps) # the constraints are linearly relaxed to r_max/d_max over 40k steps # Here we only have a warmup phase - is_warmed_up = jnp.greater_equal(steps.value, self.warm_up_steps).astype(jnp.float32) + is_warmed_up = jnp.greater_equal(steps.value, self.warmup_steps).astype(jnp.float32) custom_mean = is_warmed_up * affine_mean + (1.0 - is_warmed_up) * batch_mean custom_var = is_warmed_up * affine_var + (1.0 - is_warmed_up) * batch_var diff --git a/sbx/crossq/policies.py b/sbx/crossq/policies.py index f9dd9ee..87924d6 100644 --- a/sbx/crossq/policies.py +++ b/sbx/crossq/policies.py @@ -23,7 +23,7 @@ class Critic(nn.Module): use_batch_norm: bool = True dropout_rate: Optional[float] = None batch_norm_momentum: float = 0.99 - renorm_warm_up_steps: int = 100_000 + renorm_warmup_steps: int = 100_000 @nn.compact def __call__(self, x: jnp.ndarray, action: jnp.ndarray, train: bool = False) -> jnp.ndarray: @@ -33,7 +33,7 @@ def __call__(self, x: jnp.ndarray, action: jnp.ndarray, train: bool = False) -> x = BatchRenorm( use_running_average=not train, momentum=self.batch_norm_momentum, - warm_up_steps=self.renorm_warm_up_steps, + warmup_steps=self.renorm_warmup_steps, )(x) else: # Create dummy batchstats @@ -58,6 +58,7 @@ class VectorCritic(nn.Module): use_layer_norm: bool = False use_batch_norm: bool = True batch_norm_momentum: float = 0.99 + renorm_warmup_steps: int = 100_000 dropout_rate: Optional[float] = None n_critics: int = 2 @@ -77,6 +78,7 @@ def __call__(self, obs: jnp.ndarray, action: jnp.ndarray, train: bool = False): use_layer_norm=self.use_layer_norm, use_batch_norm=self.use_batch_norm, batch_norm_momentum=self.batch_norm_momentum, + renorm_warmup_steps=self.renorm_warmup_steps, dropout_rate=self.dropout_rate, net_arch=self.net_arch, )(obs, action, train) @@ -90,7 +92,7 @@ class Actor(nn.Module): log_std_max: float = 2 use_batch_norm: bool = True batch_norm_momentum: float = 0.99 - renorm_warm_up_steps: int = 100_000 + renorm_warmup_steps: int = 100_000 def get_std(self): # Make it work with gSDE @@ -103,7 +105,7 @@ def __call__(self, x: jnp.ndarray, train: bool = False) -> tfd.Distribution: # x = BatchRenorm( use_running_average=not train, momentum=self.batch_norm_momentum, - warm_up_steps=self.renorm_warm_up_steps, + warmup_steps=self.renorm_warmup_steps, )(x) else: # Create dummy batchstats @@ -138,6 +140,7 @@ def __init__( batch_norm: bool = True, # for critic batch_norm_actor: bool = True, batch_norm_momentum: float = 0.99, + renorm_warmup_steps: int = 100_000, use_sde: bool = False, # Note: most gSDE parameters are not used # this is to keep API consistent with SB3 @@ -174,6 +177,7 @@ def __init__( self.batch_norm = batch_norm self.batch_norm_momentum = batch_norm_momentum self.batch_norm_actor = batch_norm_actor + self.renorm_warmup_steps = renorm_warmup_steps if net_arch is not None: if isinstance(net_arch, list): @@ -211,6 +215,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) net_arch=self.net_arch_pi, use_batch_norm=self.batch_norm_actor, batch_norm_momentum=self.batch_norm_momentum, + renorm_warmup_steps=self.renorm_warmup_steps, ) # Hack to make gSDE work without modifying internal SB3 code self.actor.reset_noise = self.reset_noise @@ -237,6 +242,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) use_layer_norm=self.layer_norm, use_batch_norm=self.batch_norm, batch_norm_momentum=self.batch_norm_momentum, + renorm_warmup_steps=self.renorm_warmup_steps, net_arch=self.net_arch_qf, n_critics=self.n_critics, )