Skip to content

Commit

Permalink
Added suggested adjustment
Browse files Browse the repository at this point in the history
  • Loading branch information
paolodelia99 committed Apr 2, 2024
1 parent 94692a1 commit ca8bad1
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 6 deletions.
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
.coverage
.coverage.*
__pycache__/
**/.mypy_cache/
**/.ruff_cache/
_build/
*.npz
*.pth
Expand Down
3 changes: 1 addition & 2 deletions sbx/sac/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_probability
import tensorflow_probability.substrates.jax as tfp
from flax.training.train_state import TrainState
from gymnasium import spaces
from stable_baselines3.common.type_aliases import Schedule
Expand All @@ -14,7 +14,6 @@
from sbx.common.policies import BaseJaxPolicy, Flatten, VectorCritic
from sbx.common.type_aliases import RLTrainState

tfp = tensorflow_probability.substrates.jax
tfd = tfp.distributions


Expand Down
4 changes: 2 additions & 2 deletions tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_sac_td3(model_class) -> None:

@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG])
def test_sac_td3_policy_kwargs(model_class) -> None:
policy_kwargs = dict(activation_fn=nn.leaky_relu, net_arch=dict(pi=[64, 64], qf=[64, 64]))
policy_kwargs = dict(activation_fn=nn.leaky_relu, net_arch=dict(pi=[8], qf=[8]))

model = model_class(
"MlpPolicy", "Pendulum-v1", verbose=1, gradient_steps=1, learning_rate=1e-3, policy_kwargs=policy_kwargs
Expand Down Expand Up @@ -107,7 +107,7 @@ def test_dqn() -> None:


def test_dqn_policy_kwargs() -> None:
policy_kwargs = dict(activation_fn=nn.leaky_relu, net_arch=[128, 128])
policy_kwargs = dict(activation_fn=nn.leaky_relu, net_arch=[8])

model = DQN(
"MlpPolicy",
Expand Down

0 comments on commit ca8bad1

Please sign in to comment.