Skip to content

Commit

Permalink
Fix import, argument and add run test
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed Nov 10, 2024
1 parent bf88774 commit 2bb4cca
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 9 deletions.
3 changes: 1 addition & 2 deletions sbx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import os

from sbx.bro import BRO
from sbx.crossq import CrossQ
from sbx.ddpg import DDPG
from sbx.dqn import DQN
from sbx.ppo import PPO
from sbx.sac import SAC
from sbx.td3 import TD3
from sbx.tqc import TQC
from sbx.bro import BRO


# Read version from file
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
Expand Down
4 changes: 2 additions & 2 deletions sbx/bro/bro.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule

from sbx.bro.policies import BROPolicy
from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax
from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState
from sbx.bro.policies import BROPolicy


class EntropyCoef(nn.Module):
Expand Down Expand Up @@ -528,7 +528,7 @@ def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]:
batch_done = jax.lax.dynamic_slice_in_dim(data.dones, i * batch_size, batch_size)

(qf_state, (qf_loss_value, ent_coef_value), key) = jax.lax.cond(
distributional == True,
distributional,
# If True:
cls.update_critic_quantile,
# If False:
Expand Down
8 changes: 5 additions & 3 deletions sbx/bro/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ class BroNet(nn.Module):

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
x = nn.Dense(self.net_arch[0], self.activation_fn)(x)
x = nn.Dense(self.net_arch[0])(x)
x = nn.LayerNorm()(x)
x = self.activation_fn(x)
for n_units in self.net_arch:
x = BroNetBlock(n_units)(x)
x = BroNetBlock(n_units, self.activation_fn)(x)
return x


Expand Down Expand Up @@ -168,7 +168,9 @@ def __init__(
self.net_arch_qf = net_arch["qf"]
else:
self.net_arch_pi = [256]
# In the paper we use [512, 512] although we also use higher RR, here we use bigger network size to compensate for the smaller RR
# In the original implementation, the authors use [512, 512]
# but with a higher replay ratio (RR),
# here we use bigger network size to compensate for the smaller RR
self.net_arch_qf = [1024, 1024]
self.n_critics = n_critics
self.use_sde = use_sde
Expand Down
4 changes: 2 additions & 2 deletions tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from stable_baselines3.common.envs import BitFlippingEnv
from stable_baselines3.common.evaluation import evaluate_policy

from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ, DroQ
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ, DroQ, BRO


def check_save_load(model, model_class, tmp_path):
Expand Down Expand Up @@ -116,7 +116,7 @@ def test_dropout(model_class):
model.learn(110)


@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, DQN, CrossQ])
@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, DQN, CrossQ, BRO])
def test_policy_kwargs(model_class) -> None:
env_id = "CartPole-v1" if model_class == DQN else "Pendulum-v1"

Expand Down

0 comments on commit 2bb4cca

Please sign in to comment.