From 6af3a25156a6736a327648f122becb01f2d962fa Mon Sep 17 00:00:00 2001
From: Antonin Raffin
Date: Tue, 12 Nov 2024 09:23:18 +0100
Subject: [PATCH] Add actor for CrossQ

---
 sbx/crossq/policies.py | 67 +++++++++++++++++++++++++++++++++++-------
 1 file changed, 56 insertions(+), 11 deletions(-)

diff --git a/sbx/crossq/policies.py b/sbx/crossq/policies.py
index 700db50..674a830 100644
--- a/sbx/crossq/policies.py
+++ b/sbx/crossq/policies.py
@@ -71,13 +71,13 @@ class SimbaCritic(nn.Module):
     def __call__(self, x: jnp.ndarray, action: jnp.ndarray, train: bool = False) -> jnp.ndarray:
         x = Flatten()(x)
         x = jnp.concatenate([x, action], -1)
-        x = BatchRenorm(
+        norm_layer = partial(
+            BatchRenorm,
             use_running_average=not train,
             momentum=self.batch_norm_momentum,
             warmup_steps=self.renorm_warmup_steps,
-        )(x)
-
-        norm_layer = partial(BatchRenorm, use_running_average=not train, momentum=self.batch_norm_momentum)
+        )
+        x = norm_layer()(x)
 
         x = nn.Dense(self.net_arch[0])(x)
         for n_units in self.net_arch:
@@ -90,11 +90,7 @@ def __call__(self, x: jnp.ndarray, action: jnp.ndarray, train: bool = False) ->
             # TODO: double check where to put the dropout
             if self.dropout_rate is not None and self.dropout_rate > 0:
                 x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False)
-        x = BatchRenorm(
-            use_running_average=not train,
-            momentum=self.batch_norm_momentum,
-            warmup_steps=self.renorm_warmup_steps,
-        )(x)
+        x = norm_layer()(x)
 
         x = nn.Dense(1)(x)
         return x
@@ -169,6 +165,51 @@ def __call__(self, obs: jnp.ndarray, action: jnp.ndarray, train: bool = False):
         return q_values
 
 
+class SimbaActor(nn.Module):
+    net_arch: Sequence[int]
+    action_dim: int
+    log_std_min: float = -20
+    log_std_max: float = 2
+    use_batch_norm: bool = True
+    batch_norm_momentum: float = 0.99
+    renorm_warmup_steps: int = 100_000
+    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
+    scale_factor: int = 4
+
+    def get_std(self):
+        # Make it work with gSDE
+        return jnp.array(0.0)
+
+    @nn.compact
+    def __call__(self, x: jnp.ndarray, train: bool = False) -> tfd.Distribution:  # type: ignore[name-defined]
+        x = Flatten()(x)
+        norm_layer = partial(
+            BatchRenorm,
+            use_running_average=not train,
+            momentum=self.batch_norm_momentum,
+            warmup_steps=self.renorm_warmup_steps,
+        )
+        x = norm_layer()(x)
+        x = nn.Dense(self.net_arch[0])(x)
+
+        for n_units in self.net_arch:
+            x = SimbaResidualBlock(
+                n_units,
+                self.activation_fn,
+                self.scale_factor,
+                norm_layer,  # type: ignore[arg-type]
+            )(x)
+        x = norm_layer()(x)
+
+        mean = nn.Dense(self.action_dim)(x)
+        log_std = nn.Dense(self.action_dim)(x)
+        log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max)
+        dist = TanhTransformedDistribution(
+            tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)),
+        )
+        return dist
+
+
 class Actor(nn.Module):
     net_arch: Sequence[int]
     action_dim: int
@@ -200,7 +241,11 @@ def __call__(self, x: jnp.ndarray, train: bool = False) -> tfd.Distribution:  #
             x = nn.Dense(n_units)(x)
             x = self.activation_fn(x)
             if self.use_batch_norm:
-                x = BatchRenorm(use_running_average=not train, momentum=self.batch_norm_momentum)(x)
+                x = BatchRenorm(
+                    use_running_average=not train,
+                    momentum=self.batch_norm_momentum,
+                    warmup_steps=self.renorm_warmup_steps,
+                )(x)
 
         mean = nn.Dense(self.action_dim)(x)
         log_std = nn.Dense(self.action_dim)(x)
@@ -431,7 +476,7 @@ def __init__(
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
         n_critics: int = 2,
         share_features_extractor: bool = False,
-        actor_class: Type[nn.Module] = Actor,  # TODO: replace with Simba actor
+        actor_class: Type[nn.Module] = SimbaActor,
         vector_critic_class: Type[nn.Module] = SimbaVectorCritic,
     ):
         super().__init__(
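
Usage note (not part of the patch): a minimal sketch of how the new SimbaActor could be selected after this change, assuming CrossQ forwards policy_kwargs to CrossQPolicy.__init__ in the usual SB3/SBX fashion; the environment id and timestep budget below are illustrative only.

    from sbx import CrossQ
    from sbx.crossq.policies import SimbaActor

    # With this patch SimbaActor becomes the default actor_class of CrossQPolicy;
    # it can also be set (or swapped back to the plain Actor) explicitly via policy_kwargs.
    model = CrossQ(
        "MlpPolicy",
        "Pendulum-v1",
        policy_kwargs=dict(actor_class=SimbaActor),
        verbose=1,
    )
    model.learn(total_timesteps=10_000)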