From b5aa9a47ceda956694b8e5ffae4d4ce6753f1d97 Mon Sep 17 00:00:00 2001 From: Alex Pasquali Date: Thu, 5 Jan 2023 10:42:22 +0100 Subject: [PATCH] Deprecation of shared layers in `mlp_extractor` (#133) * Deprecation of shared layers in mlp_extractor * Fix missing import * Reformat and update tests Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 12 +++++++-- sb3_contrib/common/maskable/policies.py | 31 ++++++++++++++++++++---- sb3_contrib/common/recurrent/policies.py | 9 ++++--- sb3_contrib/ppo_mask/ppo_mask.py | 1 - sb3_contrib/version.txt | 2 +- setup.py | 2 +- tests/test_dict_env.py | 6 ++--- tests/test_lstm.py | 2 +- tests/test_run.py | 2 +- 9 files changed, 49 insertions(+), 18 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index dee9f863..bd61c9af 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,15 +3,22 @@ Changelog ========== -Release 1.7.0a11 (WIP) +Release 1.7.0a12 (WIP) -------------------------- +.. warning:: + + Shared layers in MLP policy (``mlp_extractor``) are now deprecated for PPO, A2C and TRPO. + This feature will be removed in SB3 v1.8.0 and the behavior of ``net_arch=[64, 64]`` + will create **separate** networks with the same architecture, to be consistent with the off-policy algorithms. 
+ + Breaking Changes: ^^^^^^^^^^^^^^^^^ - Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters, please use an ``EvalCallback`` instead - Removed deprecated ``sde_net_arch`` parameter -- Upgraded to Stable-Baselines3 >= 1.7.0a11 +- Upgraded to Stable-Baselines3 >= 1.7.0a12 New Features: ^^^^^^^^^^^^^ @@ -30,6 +37,7 @@ Bug Fixes: Deprecations: ^^^^^^^^^^^^^ - You should now explicitely pass a ``features_extractor`` parameter when calling ``extract_features()`` +- Deprecated shared layers in ``MlpExtractor`` (@AlexPasqua) Others: ^^^^^^^ diff --git a/sb3_contrib/common/maskable/policies.py b/sb3_contrib/common/maskable/policies.py index fbf597e1..3fb401a3 100644 --- a/sb3_contrib/common/maskable/policies.py +++ b/sb3_contrib/common/maskable/policies.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Any, Dict, List, Optional, Tuple, Type, Union @@ -46,7 +47,8 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Space, lr_schedule: Schedule, - net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + # TODO(antonin): update type annotation when we remove shared network support + net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None, activation_fn: Type[nn.Module] = nn.Tanh, ortho_init: bool = True, features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, @@ -74,12 +76,28 @@ def __init__( squash_output=False, ) + # Convert [dict()] to dict() as shared networks are deprecated + if isinstance(net_arch, list) and len(net_arch) > 0: + if isinstance(net_arch[0], dict): + warnings.warn( + ( + "As shared layers in the mlp_extractor are deprecated and will be removed in SB3 v1.8.0, " + "you should now pass directly a dictionary and not a list " + "(net_arch=dict(pi=..., vf=...) 
instead of net_arch=[dict(pi=..., vf=...)])" + ), + ) + net_arch = net_arch[0] + else: + # Note: deprecation warning will be emitted + # by the MlpExtractor constructor + pass + # Default network architecture, from stable-baselines if net_arch is None: if features_extractor_class == NatureCNN: net_arch = [] else: - net_arch = [dict(pi=[64, 64], vf=[64, 64])] + net_arch = dict(pi=[64, 64], vf=[64, 64]) self.net_arch = net_arch self.activation_fn = activation_fn @@ -95,7 +113,8 @@ def __init__( self.pi_features_extractor = self.features_extractor self.vf_features_extractor = self.make_features_extractor() # if the features extractor is not shared, there cannot be shared layers in the mlp_extractor - if len(net_arch) > 0 and not isinstance(net_arch[0], dict): + # TODO(antonin): update the check once we change net_arch behavior + if isinstance(net_arch, list) and len(net_arch) > 0: raise ValueError( "Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor" ) @@ -382,7 +401,8 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Space, lr_schedule: Schedule, - net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + # TODO(antonin): update type annotation when we remove shared network support + net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None, activation_fn: Type[nn.Module] = nn.Tanh, ortho_init: bool = True, features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, @@ -436,7 +456,8 @@ def __init__( observation_space: spaces.Dict, action_space: spaces.Space, lr_schedule: Schedule, - net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + # TODO(antonin): update type annotation when we remove shared network support + net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None, activation_fn: Type[nn.Module] = nn.Tanh, ortho_init: bool = True, features_extractor_class: Type[BaseFeaturesExtractor] = 
CombinedExtractor, diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 1dd4869b..67211735 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -66,7 +66,8 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Space, lr_schedule: Schedule, - net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + # TODO(antonin): update type annotation when we remove shared network support + net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None, activation_fn: Type[nn.Module] = nn.Tanh, ortho_init: bool = True, use_sde: bool = False, @@ -475,7 +476,8 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Space, lr_schedule: Schedule, - net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + # TODO(antonin): update type annotation when we remove shared network support + net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None, activation_fn: Type[nn.Module] = nn.Tanh, ortho_init: bool = True, use_sde: bool = False, @@ -565,7 +567,8 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Space, lr_schedule: Schedule, - net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + # TODO(antonin): update type annotation when we remove shared network support + net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None, activation_fn: Type[nn.Module] = nn.Tanh, ortho_init: bool = True, use_sde: bool = False, diff --git a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py index cbeee093..ee88f124 100644 --- a/sb3_contrib/ppo_mask/ppo_mask.py +++ b/sb3_contrib/ppo_mask/ppo_mask.py @@ -3,7 +3,6 @@ from collections import deque from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union -import gym import numpy as np import torch as th from gym import spaces diff --git 
a/sb3_contrib/version.txt b/sb3_contrib/version.txt index a02b7e49..77ca7d32 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.7.0a11 +1.7.0a12 diff --git a/setup.py b/setup.py index ba766cdc..4faafbb6 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=1.7.0a11", + "stable_baselines3>=1.7.0a12", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin", diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index 38757ab3..eada97d5 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -142,7 +142,7 @@ def test_dict_spaces(model_class, channel_last): kwargs = dict( n_steps=128, policy_kwargs=dict( - net_arch=[dict(pi=[32], vf=[32])], + net_arch=dict(pi=[32], vf=[32]), features_extractor_kwargs=dict(cnn_output_dim=32), ), ) @@ -191,7 +191,7 @@ def test_dict_vec_framestack(model_class, channel_last): kwargs = dict( n_steps=128, policy_kwargs=dict( - net_arch=[dict(pi=[32], vf=[32])], + net_arch=dict(pi=[32], vf=[32]), features_extractor_kwargs=dict(cnn_output_dim=32), ), ) @@ -234,7 +234,7 @@ def test_vec_normalize(model_class): kwargs = dict( n_steps=128, policy_kwargs=dict( - net_arch=[dict(pi=[32], vf=[32])], + net_arch=dict(pi=[32], vf=[32]), ), ) else: diff --git a/tests/test_lstm.py b/tests/test_lstm.py index 0bfdb381..dc7cab77 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -227,7 +227,7 @@ def make_env(): max_grad_norm=1, gae_lambda=0.98, policy_kwargs=dict( - net_arch=[dict(vf=[64])], + net_arch=dict(vf=[64], pi=[]), lstm_hidden_size=64, ortho_init=False, enable_critic_lstm=True, diff --git a/tests/test_run.py b/tests/test_run.py index 52385222..6753ebb3 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -77,7 +77,7 @@ def test_trpo_params(): use_sde=True, 
sub_sampling_factor=4, seed=0, - policy_kwargs=dict(net_arch=[dict(pi=[32], vf=[32])]), + policy_kwargs=dict(net_arch=dict(pi=[32], vf=[32])), verbose=1, ) model.learn(total_timesteps=500)