From ca8bad10e6383b9e0286ed44a845d2a00d3f31a6 Mon Sep 17 00:00:00 2001 From: paolodelia99 Date: Tue, 2 Apr 2024 18:22:47 +0200 Subject: [PATCH] Added suggested adjustment --- .gitignore | 2 -- sbx/sac/policies.py | 3 +-- tests/test_run.py | 4 ++-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index a700499..f3365c6 100644 --- a/.gitignore +++ b/.gitignore @@ -10,8 +10,6 @@ .coverage .coverage.* __pycache__/ -**/.mypy_cache/ -**/.ruff_cache/ _build/ *.npz *.pth diff --git a/sbx/sac/policies.py b/sbx/sac/policies.py index 18e1e3f..f936f91 100644 --- a/sbx/sac/policies.py +++ b/sbx/sac/policies.py @@ -5,7 +5,7 @@ import jax.numpy as jnp import numpy as np import optax -import tensorflow_probability +import tensorflow_probability.substrates.jax as tfp from flax.training.train_state import TrainState from gymnasium import spaces from stable_baselines3.common.type_aliases import Schedule @@ -14,7 +14,6 @@ from sbx.common.policies import BaseJaxPolicy, Flatten, VectorCritic from sbx.common.type_aliases import RLTrainState -tfp = tensorflow_probability.substrates.jax tfd = tfp.distributions diff --git a/tests/test_run.py b/tests/test_run.py index 3b4580c..d90f0bb 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -75,7 +75,7 @@ def test_sac_td3(model_class) -> None: @pytest.mark.parametrize("model_class", [SAC, TD3, DDPG]) def test_sac_td3_policy_kwargs(model_class) -> None: - policy_kwargs = dict(activation_fn=nn.leaky_relu, net_arch=dict(pi=[64, 64], qf=[64, 64])) + policy_kwargs = dict(activation_fn=nn.leaky_relu, net_arch=dict(pi=[8], qf=[8])) model = model_class( "MlpPolicy", "Pendulum-v1", verbose=1, gradient_steps=1, learning_rate=1e-3, policy_kwargs=policy_kwargs @@ -107,7 +107,7 @@ def test_dqn() -> None: def test_dqn_policy_kwargs() -> None: - policy_kwargs = dict(activation_fn=nn.leaky_relu, net_arch=[128, 128]) + policy_kwargs = dict(activation_fn=nn.leaky_relu, net_arch=[8]) model = DQN( "MlpPolicy",