From 774ac18a93767fd3dff8f3abaeb028fb58b3f5f4 Mon Sep 17 00:00:00 2001 From: kplers Date: Mon, 2 Dec 2024 21:22:19 +0900 Subject: [PATCH 1/2] Add policy documentation links to policy_kwargs parameter --- docs/modules/ppo_mask.rst | 1 + docs/modules/ppo_recurrent.rst | 1 + sb3_contrib/ars/ars.py | 2 +- sb3_contrib/crossq/crossq.py | 2 +- sb3_contrib/ppo_mask/ppo_mask.py | 2 +- sb3_contrib/ppo_recurrent/ppo_recurrent.py | 2 +- sb3_contrib/qrdqn/qrdqn.py | 2 +- sb3_contrib/tqc/tqc.py | 2 +- sb3_contrib/trpo/trpo.py | 2 +- 9 files changed, 9 insertions(+), 7 deletions(-) diff --git a/docs/modules/ppo_mask.rst b/docs/modules/ppo_mask.rst index 4ff6f06d..bf17261a 100644 --- a/docs/modules/ppo_mask.rst +++ b/docs/modules/ppo_mask.rst @@ -245,6 +245,7 @@ Parameters :members: :inherited-members: +.. _ppo_mask_policies: MaskablePPO Policies -------------------- diff --git a/docs/modules/ppo_recurrent.rst b/docs/modules/ppo_recurrent.rst index 52d02d0d..31e3d340 100644 --- a/docs/modules/ppo_recurrent.rst +++ b/docs/modules/ppo_recurrent.rst @@ -125,6 +125,7 @@ Parameters :members: :inherited-members: +.. _ppo_recurrent_policies: RecurrentPPO Policies --------------------- diff --git a/sb3_contrib/ars/ars.py b/sb3_contrib/ars/ars.py index 975ca361..b06a3cb0 100644 --- a/sb3_contrib/ars/ars.py +++ b/sb3_contrib/ars/ars.py @@ -40,7 +40,7 @@ class ARS(BaseAlgorithm): :param zero_policy: Boolean determining if the passed policy should have it's weights zeroed before training. :param alive_bonus_offset: Constant added to the reward at each step, used to cancel out alive bonuses. :param n_eval_episodes: Number of episodes to evaluate each candidate. - :param policy_kwargs: Keyword arguments to pass to the policy on creation + :param policy_kwargs: Keyword arguments to pass to the policy on creation. See :ref:`ars_policies` :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over :param tensorboard_log: String with the directory to put tensorboard logs: diff --git a/sb3_contrib/crossq/crossq.py b/sb3_contrib/crossq/crossq.py index 6fa860f1..1b7f90b8 100644 --- a/sb3_contrib/crossq/crossq.py +++ b/sb3_contrib/crossq/crossq.py @@ -56,7 +56,7 @@ class CrossQ(OffPolicyAlgorithm): :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over :param tensorboard_log: the log location for tensorboard (if None, no logging) - :param policy_kwargs: additional arguments to be passed to the policy on creation + :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`crossq_policies` :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages :param seed: Seed for the pseudo random generators diff --git a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py index f845ad6f..9fe02d0b 100644 --- a/sb3_contrib/ppo_mask/ppo_mask.py +++ b/sb3_contrib/ppo_mask/ppo_mask.py @@ -57,7 +57,7 @@ class MaskablePPO(OnPolicyAlgorithm): :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over :param tensorboard_log: the log location for tensorboard (if None, no logging) - :param policy_kwargs: additional arguments to be passed to the policy on creation + :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`ppo_mask_policies` :param verbose: the verbosity level: 0 no output, 1 info, 2 debug :param seed: Seed for the pseudo random generators :param device: Device (cpu, cuda, ...) on which the code should be run. diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index 060892f8..9aa74794 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -57,7 +57,7 @@ class RecurrentPPO(OnPolicyAlgorithm): :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over :param tensorboard_log: the log location for tensorboard (if None, no logging) - :param policy_kwargs: additional arguments to be passed to the policy on creation + :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`ppo_recurrent_policies` :param verbose: the verbosity level: 0 no output, 1 info, 2 debug :param seed: Seed for the pseudo random generators :param device: Device (cpu, cuda, ...) on which the code should be run. diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index e2c8ac34..724a6f5c 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -53,7 +53,7 @@ class QRDQN(OffPolicyAlgorithm): :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over :param tensorboard_log: the log location for tensorboard (if None, no logging) - :param policy_kwargs: additional arguments to be passed to the policy on creation + :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`qrdqn_policies` :param verbose: the verbosity level: 0 no output, 1 info, 2 debug :param seed: Seed for the pseudo random generators :param device: Device (cpu, cuda, ...) on which the code should be run. diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py index 58679dc7..99b914c6 100644 --- a/sb3_contrib/tqc/tqc.py +++ b/sb3_contrib/tqc/tqc.py @@ -60,7 +60,7 @@ class TQC(OffPolicyAlgorithm): :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over :param tensorboard_log: the log location for tensorboard (if None, no logging) - :param policy_kwargs: additional arguments to be passed to the policy on creation + :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`tqc_policies` :param verbose: the verbosity level: 0 no output, 1 info, 2 debug :param seed: Seed for the pseudo random generators :param device: Device (cpu, cuda, ...) on which the code should be run. diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index c3dc89b1..47fde551 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -64,7 +64,7 @@ class TRPO(OnPolicyAlgorithm): :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over :param tensorboard_log: the log location for tensorboard (if None, no logging) - :param policy_kwargs: additional arguments to be passed to the policy on creation + :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`trpo_policies` :param verbose: the verbosity level: 0 no output, 1 info, 2 debug :param seed: Seed for the pseudo random generators :param device: Device (cpu, cuda, ...) on which the code should be run. From 53baa5d668ca7673c10dca28052773b54ead0613 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 2 Dec 2024 14:48:12 +0100 Subject: [PATCH 2/2] Sort `__all__` --- sb3_contrib/__init__.py | 6 +++--- sb3_contrib/common/torch_layers.py | 2 +- sb3_contrib/ppo_mask/__init__.py | 2 +- sb3_contrib/qrdqn/__init__.py | 2 +- sb3_contrib/tqc/__init__.py | 2 +- sb3_contrib/trpo/__init__.py | 2 +- 6 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py index 5a5f6243..2aa7a19b 100644 --- a/sb3_contrib/__init__.py +++ b/sb3_contrib/__init__.py @@ -15,10 +15,10 @@ __all__ = [ "ARS", - "CrossQ", - "MaskablePPO", - "RecurrentPPO", "QRDQN", "TQC", "TRPO", + "CrossQ", + "MaskablePPO", + "RecurrentPPO", ] diff --git a/sb3_contrib/common/torch_layers.py b/sb3_contrib/common/torch_layers.py index 76a93bb7..2605441e 100644 --- a/sb3_contrib/common/torch_layers.py +++ b/sb3_contrib/common/torch_layers.py @@ -1,6 +1,6 @@ import torch -__all__ = ["BatchRenorm1d", "BatchRenorm"] +__all__ = ["BatchRenorm", "BatchRenorm1d"] class BatchRenorm(torch.nn.Module): diff --git a/sb3_contrib/ppo_mask/__init__.py b/sb3_contrib/ppo_mask/__init__.py index 89d4cedd..3d49cd08 100644 --- a/sb3_contrib/ppo_mask/__init__.py +++ b/sb3_contrib/ppo_mask/__init__.py @@ -1,4 +1,4 @@ from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from sb3_contrib.ppo_mask.ppo_mask import MaskablePPO -__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "MaskablePPO"] +__all__ = ["CnnPolicy", "MaskablePPO", "MlpPolicy", "MultiInputPolicy"] diff --git a/sb3_contrib/qrdqn/__init__.py b/sb3_contrib/qrdqn/__init__.py index 6f13c23d..7ce8107b 100644 --- a/sb3_contrib/qrdqn/__init__.py +++ b/sb3_contrib/qrdqn/__init__.py @@ -1,4 +1,4 @@ from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from sb3_contrib.qrdqn.qrdqn import QRDQN -__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "QRDQN"] +__all__ = ["QRDQN", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"] diff --git a/sb3_contrib/tqc/__init__.py b/sb3_contrib/tqc/__init__.py index e40a55ae..34a501aa 100644 --- a/sb3_contrib/tqc/__init__.py +++ b/sb3_contrib/tqc/__init__.py @@ -1,4 +1,4 @@ from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from sb3_contrib.tqc.tqc import TQC -__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "TQC"] +__all__ = ["TQC", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"] diff --git a/sb3_contrib/trpo/__init__.py b/sb3_contrib/trpo/__init__.py index 8d6eea78..312a20e5 100644 --- a/sb3_contrib/trpo/__init__.py +++ b/sb3_contrib/trpo/__init__.py @@ -1,4 +1,4 @@ from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from sb3_contrib.trpo.trpo import TRPO -__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "TRPO"] +__all__ = ["TRPO", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]