From 2bb4ccae5362ef1976f158a3cc82f7bcb2d36af3 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 10 Nov 2024 14:41:47 +0100 Subject: [PATCH] Fix import, argument and add run test --- sbx/__init__.py | 3 +-- sbx/bro/bro.py | 4 ++-- sbx/bro/policies.py | 8 +++++--- tests/test_run.py | 4 ++-- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/sbx/__init__.py b/sbx/__init__.py index 8ec1325..8b47be1 100644 --- a/sbx/__init__.py +++ b/sbx/__init__.py @@ -1,5 +1,6 @@ import os +from sbx.bro import BRO from sbx.crossq import CrossQ from sbx.ddpg import DDPG from sbx.dqn import DQN @@ -7,8 +8,6 @@ from sbx.sac import SAC from sbx.td3 import TD3 from sbx.tqc import TQC -from sbx.bro import BRO - # Read version from file version_file = os.path.join(os.path.dirname(__file__), "version.txt") diff --git a/sbx/bro/bro.py b/sbx/bro/bro.py index faca6da..c77afc9 100644 --- a/sbx/bro/bro.py +++ b/sbx/bro/bro.py @@ -14,9 +14,9 @@ from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from sbx.bro.policies import BROPolicy from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState -from sbx.bro.policies import BROPolicy class EntropyCoef(nn.Module): @@ -528,7 +528,7 @@ def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]: batch_done = jax.lax.dynamic_slice_in_dim(data.dones, i * batch_size, batch_size) (qf_state, (qf_loss_value, ent_coef_value), key) = jax.lax.cond( - distributional == True, + distributional, # If True: cls.update_critic_quantile, # If False: diff --git a/sbx/bro/policies.py b/sbx/bro/policies.py index f89f559..8b44b09 100644 --- a/sbx/bro/policies.py +++ b/sbx/bro/policies.py @@ -37,11 +37,11 @@ class BroNet(nn.Module): @nn.compact def __call__(self, x: jnp.ndarray) -> jnp.ndarray: - x = nn.Dense(self.net_arch[0], self.activation_fn)(x) + x = nn.Dense(self.net_arch[0])(x) x = nn.LayerNorm()(x) x = self.activation_fn(x) for n_units in self.net_arch: - x = BroNetBlock(n_units)(x) + x = BroNetBlock(n_units, self.activation_fn)(x) return x @@ -168,7 +168,9 @@ def __init__( self.net_arch_qf = net_arch["qf"] else: self.net_arch_pi = [256] - # In the paper we use [512, 512] although we also use higher RR, here we use bigger network size to compensate for the smaller RR + # In the original implementation, the authors use [512, 512] + # but with a higher replay ratio (RR), + # here we use bigger network size to compensate for the smaller RR self.net_arch_qf = [1024, 1024] self.n_critics = n_critics self.use_sde = use_sde diff --git a/tests/test_run.py b/tests/test_run.py index 18d6dec..54b676a 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -8,7 +8,7 @@ from stable_baselines3.common.envs import BitFlippingEnv from stable_baselines3.common.evaluation import evaluate_policy -from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ, DroQ +from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ, DroQ, BRO def check_save_load(model, model_class, tmp_path): @@ -116,7 +116,7 @@ def test_dropout(model_class): model.learn(110) -@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, DQN, CrossQ]) +@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, DQN, CrossQ, BRO]) def test_policy_kwargs(model_class) -> None: env_id = "CartPole-v1" if model_class == DQN else "Pendulum-v1"