Skip to content

Commit

Permalink
Deprecation of shared layers in mlp_extractor (#133)
Browse files Browse the repository at this point in the history
* Deprecation of shared layers in mlp_extractor

* Fix missing import

* Reformat and update tests

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
  • Loading branch information
AlexPasqua and araffin authored Jan 5, 2023
1 parent 7c4a249 commit b5aa9a4
Show file tree
Hide file tree
Showing 9 changed files with 49 additions and 18 deletions.
12 changes: 10 additions & 2 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,22 @@
Changelog
==========

Release 1.7.0a11 (WIP)
Release 1.7.0a12 (WIP)
--------------------------

.. warning::

Shared layers in MLP policy (``mlp_extractor``) are now deprecated for PPO, A2C and TRPO.
This feature will be removed in SB3 v1.8.0, after which ``net_arch=[64, 64]``
will create **separate** networks with the same architecture, to be consistent with the off-policy algorithms.


Breaking Changes:
^^^^^^^^^^^^^^^^^
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
please use an ``EvalCallback`` instead
- Removed deprecated ``sde_net_arch`` parameter
- Upgraded to Stable-Baselines3 >= 1.7.0a11
- Upgraded to Stable-Baselines3 >= 1.7.0a12

New Features:
^^^^^^^^^^^^^
Expand All @@ -30,6 +37,7 @@ Bug Fixes:
Deprecations:
^^^^^^^^^^^^^
- You should now explicitly pass a ``features_extractor`` parameter when calling ``extract_features()``
- Deprecated shared layers in ``MlpExtractor`` (@AlexPasqua)

Others:
^^^^^^^
Expand Down
31 changes: 26 additions & 5 deletions sb3_contrib/common/maskable/policies.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type, Union

Expand Down Expand Up @@ -46,7 +47,8 @@ def __init__(
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
# TODO(antonin): update type annotation when we remove shared network support
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
Expand Down Expand Up @@ -74,12 +76,28 @@ def __init__(
squash_output=False,
)

# Convert [dict()] to dict() as shared networks are deprecated
if isinstance(net_arch, list) and len(net_arch) > 0:
if isinstance(net_arch[0], dict):
warnings.warn(
(
"As shared layers in the mlp_extractor are deprecated and will be removed in SB3 v1.8.0, "
"you should now pass directly a dictionary and not a list "
"(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])"
),
)
net_arch = net_arch[0]
else:
# Note: deprecation warning will be emitted
# by the MlpExtractor constructor
pass

# Default network architecture, from stable-baselines
if net_arch is None:
if features_extractor_class == NatureCNN:
net_arch = []
else:
net_arch = [dict(pi=[64, 64], vf=[64, 64])]
net_arch = dict(pi=[64, 64], vf=[64, 64])

self.net_arch = net_arch
self.activation_fn = activation_fn
Expand All @@ -95,7 +113,8 @@ def __init__(
self.pi_features_extractor = self.features_extractor
self.vf_features_extractor = self.make_features_extractor()
# if the features extractor is not shared, there cannot be shared layers in the mlp_extractor
if len(net_arch) > 0 and not isinstance(net_arch[0], dict):
# TODO(antonin): update the check once we change net_arch behavior
if isinstance(net_arch, list) and len(net_arch) > 0:
raise ValueError(
"Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor"
)
Expand Down Expand Up @@ -382,7 +401,8 @@ def __init__(
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
# TODO(antonin): update type annotation when we remove shared network support
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
Expand Down Expand Up @@ -436,7 +456,8 @@ def __init__(
observation_space: spaces.Dict,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
# TODO(antonin): update type annotation when we remove shared network support
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
Expand Down
9 changes: 6 additions & 3 deletions sb3_contrib/common/recurrent/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def __init__(
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
# TODO(antonin): update type annotation when we remove shared network support
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
Expand Down Expand Up @@ -475,7 +476,8 @@ def __init__(
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
# TODO(antonin): update type annotation when we remove shared network support
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
Expand Down Expand Up @@ -565,7 +567,8 @@ def __init__(
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
# TODO(antonin): update type annotation when we remove shared network support
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
Expand Down
1 change: 0 additions & 1 deletion sb3_contrib/ppo_mask/ppo_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from collections import deque
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union

import gym
import numpy as np
import torch as th
from gym import spaces
Expand Down
2 changes: 1 addition & 1 deletion sb3_contrib/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.7.0a11
1.7.0a12
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=1.7.0a11",
"stable_baselines3>=1.7.0a12",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",
Expand Down
6 changes: 3 additions & 3 deletions tests/test_dict_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def test_dict_spaces(model_class, channel_last):
kwargs = dict(
n_steps=128,
policy_kwargs=dict(
net_arch=[dict(pi=[32], vf=[32])],
net_arch=dict(pi=[32], vf=[32]),
features_extractor_kwargs=dict(cnn_output_dim=32),
),
)
Expand Down Expand Up @@ -191,7 +191,7 @@ def test_dict_vec_framestack(model_class, channel_last):
kwargs = dict(
n_steps=128,
policy_kwargs=dict(
net_arch=[dict(pi=[32], vf=[32])],
net_arch=dict(pi=[32], vf=[32]),
features_extractor_kwargs=dict(cnn_output_dim=32),
),
)
Expand Down Expand Up @@ -234,7 +234,7 @@ def test_vec_normalize(model_class):
kwargs = dict(
n_steps=128,
policy_kwargs=dict(
net_arch=[dict(pi=[32], vf=[32])],
net_arch=dict(pi=[32], vf=[32]),
),
)
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def make_env():
max_grad_norm=1,
gae_lambda=0.98,
policy_kwargs=dict(
net_arch=[dict(vf=[64])],
net_arch=dict(vf=[64], pi=[]),
lstm_hidden_size=64,
ortho_init=False,
enable_critic_lstm=True,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_trpo_params():
use_sde=True,
sub_sampling_factor=4,
seed=0,
policy_kwargs=dict(net_arch=[dict(pi=[32], vf=[32])]),
policy_kwargs=dict(net_arch=dict(pi=[32], vf=[32])),
verbose=1,
)
model.learn(total_timesteps=500)
Expand Down

0 comments on commit b5aa9a4

Please sign in to comment.