Add actor for CrossQ
araffin committed Nov 12, 2024 · 1 parent 6e533e0 · commit 6af3a25
Showing 1 changed file with 56 additions and 11 deletions: sbx/crossq/policies.py
@@ -71,13 +71,13 @@ class SimbaCritic(nn.Module):
     def __call__(self, x: jnp.ndarray, action: jnp.ndarray, train: bool = False) -> jnp.ndarray:
         x = Flatten()(x)
         x = jnp.concatenate([x, action], -1)
-        x = BatchRenorm(
+        norm_layer = partial(
+            BatchRenorm,
             use_running_average=not train,
             momentum=self.batch_norm_momentum,
             warmup_steps=self.renorm_warmup_steps,
-        )(x)
-
-        norm_layer = partial(BatchRenorm, use_running_average=not train, momentum=self.batch_norm_momentum)
+        )
+        x = norm_layer()(x)
         x = nn.Dense(self.net_arch[0])(x)

         for n_units in self.net_arch:
@@ -90,11 +90,7 @@ def __call__(self, x: jnp.ndarray, action: jnp.ndarray, train: bool = False) ->
         # TODO: double check where to put the dropout
         if self.dropout_rate is not None and self.dropout_rate > 0:
             x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False)
-        x = BatchRenorm(
-            use_running_average=not train,
-            momentum=self.batch_norm_momentum,
-            warmup_steps=self.renorm_warmup_steps,
-        )(x)
+        x = norm_layer()(x)
         x = nn.Dense(1)(x)
         return x
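The two hunks above fold the repeated BatchRenorm(...) constructions into a single functools.partial binding, so every normalization layer in the critic shares one configuration, including warmup_steps. A minimal standalone sketch of the pattern (illustrative values only, with flax.linen.BatchNorm standing in for sbx's BatchRenorm):

from functools import partial

import flax.linen as nn

# Bind the shared hyper-parameters once (values here are illustrative only);
# each norm_layer() call then builds a layer with an identical configuration.
norm_layer = partial(nn.BatchNorm, use_running_average=True, momentum=0.99)

input_norm = norm_layer()  # e.g. applied right after the input concatenation
final_norm = norm_layer()  # reused again before the last Dense layer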

@@ -169,6 +165,51 @@ def __call__(self, obs: jnp.ndarray, action: jnp.ndarray, train: bool = False):
         return q_values


+class SimbaActor(nn.Module):
+    net_arch: Sequence[int]
+    action_dim: int
+    log_std_min: float = -20
+    log_std_max: float = 2
+    use_batch_norm: bool = True
+    batch_norm_momentum: float = 0.99
+    renorm_warmup_steps: int = 100_000
+    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
+    scale_factor: int = 4
+
+    def get_std(self):
+        # Make it work with gSDE
+        return jnp.array(0.0)
+
+    @nn.compact
+    def __call__(self, x: jnp.ndarray, train: bool = False) -> tfd.Distribution:  # type: ignore[name-defined]
+        x = Flatten()(x)
+        norm_layer = partial(
+            BatchRenorm,
+            use_running_average=not train,
+            momentum=self.batch_norm_momentum,
+            warmup_steps=self.renorm_warmup_steps,
+        )
+        x = norm_layer()(x)
+        x = nn.Dense(self.net_arch[0])(x)
+
+        for n_units in self.net_arch:
+            x = SimbaResidualBlock(
+                n_units,
+                self.activation_fn,
+                self.scale_factor,
+                norm_layer,  # type: ignore[arg-type]
+            )(x)
+        x = norm_layer()(x)
+
+        mean = nn.Dense(self.action_dim)(x)
+        log_std = nn.Dense(self.action_dim)(x)
+        log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max)
+        dist = TanhTransformedDistribution(
+            tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)),
+        )
+        return dist
+
+
 class Actor(nn.Module):
     net_arch: Sequence[int]
     action_dim: int
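Below is a hypothetical usage sketch of the new SimbaActor (not part of the commit), assuming the class is importable from sbx.crossq.policies; the observation and action sizes are placeholders. It only illustrates that the module returns a tanh-squashed Gaussian that can be sampled and scored:

import jax
import jax.numpy as jnp

from sbx.crossq.policies import SimbaActor

actor = SimbaActor(net_arch=[256], action_dim=3)
obs = jnp.zeros((1, 17))  # dummy observation batch (shapes are placeholders)

# train=False keeps BatchRenorm in inference mode, so no mutable state is needed.
params = actor.init(jax.random.PRNGKey(0), obs, train=False)
dist = actor.apply(params, obs, train=False)  # TanhTransformedDistribution
action = dist.sample(seed=jax.random.PRNGKey(1))  # values squashed into [-1, 1]
log_prob = dist.log_prob(action)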
@@ -200,7 +241,11 @@ def __call__(self, x: jnp.ndarray, train: bool = False) -> tfd.Distribution: #
             x = nn.Dense(n_units)(x)
             x = self.activation_fn(x)
             if self.use_batch_norm:
-                x = BatchRenorm(use_running_average=not train, momentum=self.batch_norm_momentum)(x)
+                x = BatchRenorm(
+                    use_running_average=not train,
+                    momentum=self.batch_norm_momentum,
+                    warmup_steps=self.renorm_warmup_steps,
+                )(x)

         mean = nn.Dense(self.action_dim)(x)
         log_std = nn.Dense(self.action_dim)(x)
@@ -431,7 +476,7 @@ def __init__(
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
         n_critics: int = 2,
         share_features_extractor: bool = False,
-        actor_class: Type[nn.Module] = Actor,  # TODO: replace with Simba actor
+        actor_class: Type[nn.Module] = SimbaActor,  # TODO: replace with Simba actor
         vector_critic_class: Type[nn.Module] = SimbaVectorCritic,
     ):
         super().__init__(
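With SimbaActor now the default actor_class, no extra configuration should be required, but the networks can also be selected explicitly. A hypothetical sketch (not part of the commit), assuming the usual sbx/Stable-Baselines convention that policy_kwargs is forwarded to the policy constructor above, with Pendulum-v1 as a placeholder environment:

from sbx import CrossQ
from sbx.crossq.policies import SimbaActor, SimbaVectorCritic

model = CrossQ(
    "MlpPolicy",
    "Pendulum-v1",
    # Explicitly pick the Simba networks; per the diff these are already the defaults.
    policy_kwargs=dict(
        actor_class=SimbaActor,
        vector_critic_class=SimbaVectorCritic,
    ),
    verbose=1,
)
model.learn(total_timesteps=5_000)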
