Deprecate DroQ class
araffin committed Apr 3, 2024
1 parent f1426cd commit a7b3135
Showing 3 changed files with 31 additions and 22 deletions.
README.md (14 changes: 8 additions & 6 deletions)
@@ -37,7 +37,7 @@ pip install sbx-rl
```python
import gymnasium as gym

-from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ, CrossQ
+from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ

env = gym.make("Pendulum-v1", render_mode="human")

@@ -62,11 +62,12 @@ Since SBX shares the SB3 API, it is compatible with the [RL Zoo](https://github.
import rl_zoo3
import rl_zoo3.train
from rl_zoo3.train import train
-from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ, CrossQ
+from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ

rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
-rl_zoo3.ALGOS["droq"] = DroQ
+# See note below to use DroQ configuration
+# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
@@ -91,11 +92,12 @@ The same goes for the enjoy script:
import rl_zoo3
import rl_zoo3.enjoy
from rl_zoo3.enjoy import enjoy
-from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ, CrossQ
+from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ

rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
-rl_zoo3.ALGOS["droq"] = DroQ
+# See note below to use DroQ configuration
+# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
@@ -128,7 +130,7 @@ and then using the RL Zoo script defined above: `python train.py --algo sac --en
We recommend playing with the `policy_delay` and `gradient_steps` parameters for better speed/efficiency.
Having a higher learning rate for the q-value function is also helpful: `qf_learning_rate: !!float 1e-3`.


-
+Note: when using the DroQ configuration with CrossQ, you should set `layer_norm=False` as there is already batch normalization.

## Citing the Project

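For reference, the DroQ configuration mentioned in the README above can be reproduced directly with the `SAC` class. The sketch below is illustrative: `policy_delay`, `gradient_steps`, and `qf_learning_rate` are named in the README, but the specific values here (and `dropout_rate=0.01`) are assumptions, not the tuned RL Zoo settings.

```python
import gymnasium as gym

from sbx import SAC

env = gym.make("Pendulum-v1")

# DroQ = SAC with dropout + layer normalization on the Q-network and a
# high update-to-data ratio. All values below are illustrative.
model = SAC(
    "MlpPolicy",
    env,
    policy_kwargs=dict(dropout_rate=0.01, layer_norm=True),
    gradient_steps=20,      # high update-to-data ratio
    policy_delay=20,        # update the policy less often, for speed
    qf_learning_rate=1e-3,  # higher learning rate for the q-value function
    verbose=1,
)
model.learn(total_timesteps=5_000)
```

With `CrossQ`, the note above applies: pass `layer_norm=False` in `policy_kwargs`, since CrossQ already uses batch normalization.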
sbx/droq/droq.py (6 changes: 6 additions & 0 deletions)
@@ -1,3 +1,4 @@
+import warnings
from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union

from stable_baselines3.common.buffers import ReplayBuffer
@@ -76,5 +77,10 @@ def __init__(
        self.policy_kwargs["dropout_rate"] = dropout_rate
        self.policy_kwargs["layer_norm"] = layer_norm

+        warnings.warn(
+            "Using DroQ class directly is deprecated and will be removed in v0.14.0 of SBX. "
+            "Please use SAC/TQC/CrossQ instead with the DroQ configuration, see https://github.com/araffin/sbx?tab=readme-ov-file#note-about-droq"
+        )
+
        if _init_setup_model:
            self._setup_model()
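After this change, simply constructing a `DroQ` instance emits a `UserWarning` (the category `warnings.warn` uses by default). A small sketch of how a caller could assert on the warning, or temporarily silence it while migrating, using only the standard library; the constructor arguments are taken from the test below:

```python
import warnings

from sbx import DroQ

# Constructing DroQ now triggers the deprecation warning at __init__ time.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model = DroQ("MlpPolicy", "Pendulum-v1", dropout_rate=0.001, layer_norm=True)

assert any("deprecated" in str(w.message) for w in caught)

# To silence it while migrating to SAC/TQC/CrossQ (the regex is matched
# against the start of the warning message):
warnings.filterwarnings("ignore", message="Using DroQ class directly is deprecated")
```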
tests/test_run.py (33 changes: 17 additions & 16 deletions)
@@ -25,22 +25,23 @@ def check_save_load(model, model_class, tmp_path):


def test_droq(tmp_path):
-    model = DroQ(
-        "MlpPolicy",
-        "Pendulum-v1",
-        learning_starts=50,
-        learning_rate=1e-3,
-        tau=0.02,
-        gamma=0.98,
-        verbose=1,
-        buffer_size=5000,
-        gradient_steps=2,
-        ent_coef="auto_1.0",
-        seed=1,
-        dropout_rate=0.001,
-        layer_norm=True,
-        # action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)),
-    )
+    with pytest.warns(UserWarning, match="deprecated"):
+        model = DroQ(
+            "MlpPolicy",
+            "Pendulum-v1",
+            learning_starts=50,
+            learning_rate=1e-3,
+            tau=0.02,
+            gamma=0.98,
+            verbose=1,
+            buffer_size=5000,
+            gradient_steps=2,
+            ent_coef="auto_1.0",
+            seed=1,
+            dropout_rate=0.001,
+            layer_norm=True,
+            # action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)),
+        )
    model.learn(total_timesteps=1500)
    # Check that something was learned
    evaluate_policy(model, model.get_env(), reward_threshold=-800)
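For context, `evaluate_policy` from stable_baselines3 raises an `AssertionError` when the mean reward falls below `reward_threshold`, which is how this test verifies that something was learned. A standalone sketch of the same pattern; the model, environment, and hyperparameters are illustrative, not the tuned values from the test:

```python
from stable_baselines3.common.evaluation import evaluate_policy

from sbx import SAC

# Short training run; settings are illustrative, not tuned.
model = SAC("MlpPolicy", "Pendulum-v1", learning_starts=50, seed=1, verbose=0)
model.learn(total_timesteps=1500)

# Returns (mean_reward, std_reward); raises AssertionError if
# mean_reward < reward_threshold.
mean_reward, std_reward = evaluate_policy(
    model, model.get_env(), n_eval_episodes=10, reward_threshold=-800
)
print(f"mean reward: {mean_reward:.1f} +/- {std_reward:.1f}")
```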
