Skip to content

Commit

Permalink
Add type hints to the PPO policies for consistency with the other policies
Browse files (browse the repository at this point in the history)
  • Loading branch information
paolodelia99 committed Apr 2, 2024
1 parent 094e635 commit 94692a1
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions sbx/ppo/policies.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -20,7 +20,7 @@

class Critic(nn.Module):
n_units: int = 256
activation_fn: Callable = nn.tanh
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
Expand All @@ -37,7 +37,7 @@ class Actor(nn.Module):
action_dim: int
n_units: int = 256
log_std_init: float = 0.0
activation_fn: Callable = nn.tanh
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh
# For Discrete, MultiDiscrete and MultiBinary actions
num_discrete_choices: Optional[Union[int, Sequence[int]]] = None
# For MultiDiscrete
Expand Down Expand Up @@ -104,7 +104,7 @@ def __init__(
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
ortho_init: bool = False,
log_std_init: float = 0.0,
activation_fn=nn.tanh,
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh,
use_sde: bool = False,
# Note: most gSDE parameters are not used
# this is to keep API consistent with SB3
Expand Down

0 comments on commit 94692a1

Please sign in to comment.