From 87d250ca5163b50a6fd5bf91ca856853e6fdf9bb Mon Sep 17 00:00:00 2001
From: Antonin Raffin
Date: Fri, 26 Nov 2021 10:40:35 +0100
Subject: [PATCH 1/8] Fix evaluation script for RNN

---
 docs/misc/changelog.rst                |  4 ++--
 stable_baselines3/common/base_class.py | 11 +++++++----
 stable_baselines3/common/evaluation.py |  6 +++---
 stable_baselines3/common/policies.py   |  8 +++++---
 stable_baselines3/version.txt          |  2 +-
 5 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 2de497913..a6a4f77ac 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -4,7 +4,7 @@
 Changelog
 ==========
 
-Release 1.3.1a2 (WIP)
+Release 1.3.1a3 (WIP)
 ---------------------------
 
 Breaking Changes:
@@ -20,7 +20,7 @@ Bug Fixes:
 ^^^^^^^^^^
 - Fixed a bug where ``set_env()`` with ``VecNormalize`` would result in an error with off-policy algorithms (thanks @cleversonahum)
 - FPS calculation is now performed based on number of steps performed during last ``learn`` call, even when ``reset_num_timesteps`` is set to ``False`` (@kachayev)
-
+- Fixed evaluation script for recurrent policies (experimental feature in SB3 contrib)
 
 Deprecations:
 ^^^^^^^^^^^^^
diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py
index 4bad2c2ca..e0e92ec10 100644
--- a/stable_baselines3/common/base_class.py
+++ b/stable_baselines3/common/base_class.py
@@ -545,15 +545,18 @@ def predict(
         state: Optional[np.ndarray] = None,
         mask: Optional[np.ndarray] = None,
         deterministic: bool = False,
-    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
         """
-        Get the model's action(s) from an observation
+        Get the policy action and state from an observation (and optional state).
+        Includes sugar-coating to handle different observations (e.g. normalizing images).
 
         :param observation: the input observation
-        :param state: The last states (can be None, used in recurrent policies)
+        :param state: The last hidden states (can be None, used in recurrent policies)
         :param mask: The last masks (can be None, used in recurrent policies)
+            this corresponds to the beginning of episodes,
+            where the hidden states of the RNN must be reset.
         :param deterministic: Whether or not to return deterministic actions.
 
-        :return: the model's action and the next state
+        :return: the model's action and the next hidden state (used in recurrent policies)
         """
         return self.policy.predict(observation, state, mask, deterministic)
diff --git a/stable_baselines3/common/evaluation.py b/stable_baselines3/common/evaluation.py
index ae05f4919..6d7960d98 100644
--- a/stable_baselines3/common/evaluation.py
+++ b/stable_baselines3/common/evaluation.py
@@ -81,8 +81,9 @@ def evaluate_policy(
     current_lengths = np.zeros(n_envs, dtype="int")
     observations = env.reset()
     states = None
+    episode_starts = np.ones((env.num_envs,), dtype=bool)
     while (episode_counts < episode_count_targets).any():
-        actions, states = model.predict(observations, state=states, deterministic=deterministic)
+        actions, states = model.predict(observations, state=states, mask=episode_starts, deterministic=deterministic)
         observations, rewards, dones, infos = env.step(actions)
         current_rewards += rewards
         current_lengths += 1
@@ -93,6 +94,7 @@ def evaluate_policy(
                 reward = rewards[i]
                 done = dones[i]
                 info = infos[i]
+                episode_starts[i] = done
 
                 if callback is not None:
                     callback(locals(), globals())
@@ -116,8 +118,6 @@ def evaluate_policy(
                         episode_counts[i] += 1
                     current_rewards[i] = 0
                     current_lengths[i] = 0
-                    if states is not None:
-                        states[i] *= 0
 
     if render:
         env.render()
diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py
index 9b4559274..b9e74fec5 100644
--- a/stable_baselines3/common/policies.py
+++ b/stable_baselines3/common/policies.py
@@ -310,16 +310,18 @@ def predict(
         state: Optional[np.ndarray] = None,
         mask: Optional[np.ndarray] = None,
         deterministic: bool = False,
-    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
         """
         Get the policy action and state from an observation (and optional state).
         Includes sugar-coating to handle different observations (e.g. normalizing images).
 
         :param observation: the input observation
-        :param state: The last states (can be None, used in recurrent policies)
+        :param state: The last hidden states (can be None, used in recurrent policies)
         :param mask: The last masks (can be None, used in recurrent policies)
+            this corresponds to the beginning of episodes,
+            where the hidden states of the RNN must be reset.
         :param deterministic: Whether or not to return deterministic actions.
 
-        :return: the model's action and the next state
+        :return: the model's action and the next hidden state (used in recurrent policies)
         """
         # TODO (GH/1): add support for RNN policies
diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt
index c5813fb1b..896c1f343 100644
--- a/stable_baselines3/version.txt
+++ b/stable_baselines3/version.txt
@@ -1 +1 @@
-1.3.1a2
+1.3.1a3
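Note on the patch above: the evaluation loop now owns the RNN reset logic. `episode_starts` begins all-True (hidden states must be reset at t=0) and is refreshed from `dones`, so a recurrent policy clears its hidden state exactly at episode boundaries instead of relying on the old in-place `states[i] *= 0` hack. A minimal sketch of the resulting calling pattern, where `model` and the vectorized `env` are assumed placeholders:

    import numpy as np

    # all-True at t=0: every env is at the beginning of an episode
    episode_starts = np.ones((env.num_envs,), dtype=bool)
    states = None
    observations = env.reset()
    for _ in range(1_000):
        # `mask` tells a recurrent policy which hidden states to reset
        actions, states = model.predict(
            observations, state=states, mask=episode_starts, deterministic=True
        )
        observations, rewards, dones, infos = env.step(actions)
        # an env that just finished is at an episode start on the next call
        episode_starts = dones.copy()
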
From 8d69b6cf4de2cd13aecfb425bd3145fad6a6c49a Mon Sep 17 00:00:00 2001
From: Antonin Raffin
Date: Fri, 26 Nov 2021 10:53:37 +0100
Subject: [PATCH 2/8] Add error message

---
 stable_baselines3/common/policies.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py
index b9e74fec5..e22ed9e52 100644
--- a/stable_baselines3/common/policies.py
+++ b/stable_baselines3/common/policies.py
@@ -324,11 +324,14 @@ def predict(
 
         :return: the model's action and the next hidden state (used in recurrent policies)
         """
-        # TODO (GH/1): add support for RNN policies
-        # if state is None:
-        #     state = self.initial_state
-        # if mask is None:
-        #     mask = [False for _ in range(self.n_envs)]
+        # Support for recurrent policies will be done in SB3 Contrib
+        # See https://github.com/DLR-RM/stable-baselines3/issues/18
+        if state is not None or mask is not None:
+            raise ValueError(
+                "You have passed a `state` or a `mask` argument to the predict() method "
+                "but recurrent policies aren't supported in SB3 yet (take a look at SB3 contrib for that). "
+                "You should probably explicitly pass `deterministic=True|False`."
+            )
 
         # Switch to eval mode (this affects batch norm / dropout)
         self.set_training_mode(False)
@@ -352,7 +355,7 @@ def predict(
         if not vectorized_env:
             actions = actions[0]
 
-        return actions, state
+        return actions, None
 
     def scale_action(self, action: np.ndarray) -> np.ndarray:
         """

From 2d13a701fcbd15e55bbc502c89655cb430ff0b66 Mon Sep 17 00:00:00 2001
From: Antonin Raffin
Date: Fri, 26 Nov 2021 10:54:39 +0100
Subject: [PATCH 3/8] Revert "Add error message"

This reverts commit 8d69b6cf4de2cd13aecfb425bd3145fad6a6c49a.

---
 stable_baselines3/common/policies.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py
index e22ed9e52..b9e74fec5 100644
--- a/stable_baselines3/common/policies.py
+++ b/stable_baselines3/common/policies.py
@@ -324,14 +324,11 @@ def predict(
 
         :return: the model's action and the next hidden state (used in recurrent policies)
         """
-        # Support for recurrent policies will be done in SB3 Contrib
-        # See https://github.com/DLR-RM/stable-baselines3/issues/18
-        if state is not None or mask is not None:
-            raise ValueError(
-                "You have passed a `state` or a `mask` argument to the predict() method "
-                "but recurrent policies aren't supported in SB3 yet (take a look at SB3 contrib for that). "
-                "You should probably explicitly pass `deterministic=True|False`."
-            )
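The quick revert is explained by the first patch: `evaluate_policy()` now forwards `mask=episode_starts` on every call, so the guard added in PATCH 2 would have raised for any standard feed-forward model under evaluation, not just for actual misuse. An illustrative repro of the conflict, assuming any non-recurrent algorithm such as PPO:

    from stable_baselines3 import PPO
    from stable_baselines3.common.evaluation import evaluate_policy

    model = PPO("MlpPolicy", "CartPole-v1").learn(1_000)
    # With the PATCH 2 guard in place this call would raise ValueError,
    # because evaluate_policy() now always passes mask=episode_starts
    # down to predict(), even for non-recurrent policies.
    mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=5)
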
+        # TODO (GH/1): add support for RNN policies
+        # if state is None:
+        #     state = self.initial_state
+        # if mask is None:
+        #     mask = [False for _ in range(self.n_envs)]
 
         # Switch to eval mode (this affects batch norm / dropout)
         self.set_training_mode(False)
@@ -355,7 +352,7 @@ def predict(
         if not vectorized_env:
             actions = actions[0]
 
-        return actions, None
+        return actions, state
 
     def scale_action(self, action: np.ndarray) -> np.ndarray:
         """

From b72d147df9d51422c60ca4a8f3b0c7d93894e354 Mon Sep 17 00:00:00 2001
From: Antonin Raffin
Date: Fri, 26 Nov 2021 10:57:02 +0100
Subject: [PATCH 4/8] Fix for pytype

---
 stable_baselines3/common/policies.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py
index b9e74fec5..450675e96 100644
--- a/stable_baselines3/common/policies.py
+++ b/stable_baselines3/common/policies.py
@@ -352,7 +352,7 @@ def predict(
         if not vectorized_env:
             actions = actions[0]
 
-        return actions, state
+        return actions, None
 
     def scale_action(self, action: np.ndarray) -> np.ndarray:
         """

From a34b3c7a30d1eba7a00b6190342163b923d4db60 Mon Sep 17 00:00:00 2001
From: Antonin Raffin
Date: Sun, 28 Nov 2021 17:55:38 +0100
Subject: [PATCH 5/8] Rename mask to `episode_start`

---
 docs/misc/changelog.rst                |  1 +
 stable_baselines3/common/base_class.py |  8 ++++----
 stable_baselines3/common/evaluation.py |  2 +-
 stable_baselines3/common/policies.py   | 10 +++++-----
 stable_baselines3/dqn/dqn.py           |  6 +++---
 5 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index c0eb8199a..441a04ed6 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -9,6 +9,7 @@ Release 1.3.1a3 (WIP)
 
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
+- Renamed ``mask`` argument of the ``predict()`` method to ``episode_start`` (used with RNN policies only)
 
 New Features:
 ^^^^^^^^^^^^^
diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py
index e0e92ec10..c4c0940b5 100644
--- a/stable_baselines3/common/base_class.py
+++ b/stable_baselines3/common/base_class.py
@@ -543,23 +543,23 @@ def predict(
         self,
         observation: np.ndarray,
         state: Optional[np.ndarray] = None,
-        mask: Optional[np.ndarray] = None,
+        episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
     ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
         """
-        Get the policy action and state from an observation (and optional state).
+        Get the policy action and state from an observation (and optional hidden state).
         Includes sugar-coating to handle different observations (e.g. normalizing images).
 
         :param observation: the input observation
         :param state: The last hidden states (can be None, used in recurrent policies)
-        :param mask: The last masks (can be None, used in recurrent policies)
+        :param episode_start: The last masks (can be None, used in recurrent policies)
            this corresponds to the beginning of episodes,
            where the hidden states of the RNN must be reset.
         :param deterministic: Whether or not to return deterministic actions.
 
         :return: the model's action and the next hidden state (used in recurrent policies)
         """
-        return self.policy.predict(observation, state, mask, deterministic)
+        return self.policy.predict(observation, state, episode_start, deterministic)
 
     def set_random_seed(self, seed: Optional[int] = None) -> None:
         """
diff --git a/stable_baselines3/common/evaluation.py b/stable_baselines3/common/evaluation.py
index 6d7960d98..e3f14d3f8 100644
--- a/stable_baselines3/common/evaluation.py
+++ b/stable_baselines3/common/evaluation.py
@@ -83,7 +83,7 @@ def evaluate_policy(
     states = None
     episode_starts = np.ones((env.num_envs,), dtype=bool)
     while (episode_counts < episode_count_targets).any():
-        actions, states = model.predict(observations, state=states, mask=episode_starts, deterministic=deterministic)
+        actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=deterministic)
         observations, rewards, dones, infos = env.step(actions)
         current_rewards += rewards
         current_lengths += 1
diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py
index 450675e96..c4a8e224b 100644
--- a/stable_baselines3/common/policies.py
+++ b/stable_baselines3/common/policies.py
@@ -308,16 +308,16 @@ def predict(
         self,
         observation: Union[np.ndarray, Dict[str, np.ndarray]],
         state: Optional[np.ndarray] = None,
-        mask: Optional[np.ndarray] = None,
+        episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
     ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
         """
-        Get the policy action and state from an observation (and optional state).
+        Get the policy action and state from an observation (and optional hidden state).
         Includes sugar-coating to handle different observations (e.g. normalizing images).
 
         :param observation: the input observation
         :param state: The last hidden states (can be None, used in recurrent policies)
-        :param mask: The last masks (can be None, used in recurrent policies)
+        :param episode_start: The last masks (can be None, used in recurrent policies)
            this corresponds to the beginning of episodes,
            where the hidden states of the RNN must be reset.
         :param deterministic: Whether or not to return deterministic actions.
@@ -327,8 +327,8 @@ def predict(
         # TODO (GH/1): add support for RNN policies
         # if state is None:
         #     state = self.initial_state
-        # if mask is None:
-        #     mask = [False for _ in range(self.n_envs)]
+        # if episode_start is None:
+        #     episode_start = [False for _ in range(self.n_envs)]
 
         # Switch to eval mode (this affects batch norm / dropout)
         self.set_training_mode(False)
diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py
index d3b1eb583..539e0a682 100644
--- a/stable_baselines3/dqn/dqn.py
+++ b/stable_baselines3/dqn/dqn.py
@@ -199,7 +199,7 @@ def predict(
         self,
         observation: np.ndarray,
         state: Optional[np.ndarray] = None,
-        mask: Optional[np.ndarray] = None,
+        episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
     ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
         """
@@ -207,7 +207,7 @@ def predict(
         :param observation: the input observation
         :param state: The last states (can be None, used in recurrent policies)
-        :param mask: The last masks (can be None, used in recurrent policies)
+        :param episode_start: The last masks (can be None, used in recurrent policies)
         :param deterministic: Whether or not to return deterministic actions.
 
         :return: the model's action and the next state (used in recurrent policies)
         """
@@ -222,7 +222,7 @@ def predict(
             else:
                 action = np.array(self.action_space.sample())
         else:
-            action, state = self.policy.predict(observation, state, mask, deterministic)
+            action, state = self.policy.predict(observation, state, episode_start, deterministic)
         return action, state
 
     def learn(
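After the rename the semantics are unchanged: True still marks environments whose episode just (re)started, i.e. where a recurrent policy must reset its hidden state; only the keyword moves from `mask=` to `episode_start=`. A sketch of the updated calling convention, with `model` and `vec_env` as assumed placeholders:

    import numpy as np

    episode_starts = np.ones((vec_env.num_envs,), dtype=bool)
    states = None
    obs = vec_env.reset()
    for _ in range(1_000):
        # `episode_start` replaces the old `mask` keyword argument
        actions, states = model.predict(
            obs, state=states, episode_start=episode_starts, deterministic=True
        )
        obs, rewards, dones, infos = vec_env.step(actions)
        episode_starts = dones.copy()
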
From 2a454ed4286724e89f1e242160d5f04195741a69 Mon Sep 17 00:00:00 2001
From: Antonin Raffin
Date: Sun, 28 Nov 2021 18:01:31 +0100
Subject: [PATCH 6/8] Fix type hint

---
 stable_baselines3/dqn/dqn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py
index 539e0a682..76d640b16 100644
--- a/stable_baselines3/dqn/dqn.py
+++ b/stable_baselines3/dqn/dqn.py
@@ -201,7 +201,7 @@ def predict(
         state: Optional[np.ndarray] = None,
         episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
-    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
         """
         Overrides the base_class predict function to include epsilon-greedy exploration.
 

From deab2eed64c7cd6a02753f4ab03c4bb3f48ae8f9 Mon Sep 17 00:00:00 2001
From: Antonin Raffin
Date: Sun, 28 Nov 2021 18:12:51 +0100
Subject: [PATCH 7/8] Fix type hints

---
 stable_baselines3/common/base_class.py | 2 +-
 stable_baselines3/common/policies.py   | 4 ++--
 stable_baselines3/dqn/dqn.py           | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py
index c4c0940b5..88bf60b94 100644
--- a/stable_baselines3/common/base_class.py
+++ b/stable_baselines3/common/base_class.py
@@ -542,7 +542,7 @@ def learn(
     def predict(
         self,
         observation: np.ndarray,
-        state: Optional[np.ndarray] = None,
+        state: Optional[Tuple[np.ndarray, ...]] = None,
         episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
     ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py
index c4a8e224b..59e1eaa36 100644
--- a/stable_baselines3/common/policies.py
+++ b/stable_baselines3/common/policies.py
@@ -307,7 +307,7 @@ def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Te
     def predict(
         self,
         observation: Union[np.ndarray, Dict[str, np.ndarray]],
-        state: Optional[np.ndarray] = None,
+        state: Optional[Tuple[np.ndarray, ...]] = None,
         episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
     ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
@@ -352,7 +352,7 @@ def predict(
         if not vectorized_env:
             actions = actions[0]
 
-        return actions, None
+        return actions, state
 
     def scale_action(self, action: np.ndarray) -> np.ndarray:
         """
diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py
index 76d640b16..11e7ac711 100644
--- a/stable_baselines3/dqn/dqn.py
+++ b/stable_baselines3/dqn/dqn.py
@@ -198,7 +198,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
     def predict(
         self,
         observation: np.ndarray,
-        state: Optional[np.ndarray] = None,
+        state: Optional[Tuple[np.ndarray, ...]] = None,
         episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
     ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:

From 3a664a4e76a6eeac97944e74dcac1be5bd5967e8 Mon Sep 17 00:00:00 2001
From: Antonin Raffin
Date: Tue, 30 Nov 2021 09:59:32 +0100
Subject: [PATCH 8/8] Remove confusing part of sentence

---
 stable_baselines3/common/base_class.py | 2 +-
 stable_baselines3/common/policies.py   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py
index 88bf60b94..757bafee3 100644
--- a/stable_baselines3/common/base_class.py
+++ b/stable_baselines3/common/base_class.py
@@ -547,7 +547,7 @@ def predict(
         deterministic: bool = False,
     ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
         """
-        Get the policy action and state from an observation (and optional hidden state).
+        Get the policy action from an observation (and optional hidden state).
         Includes sugar-coating to handle different observations (e.g. normalizing images).
 
         :param observation: the input observation
diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py
index 59e1eaa36..33918b784 100644
--- a/stable_baselines3/common/policies.py
+++ b/stable_baselines3/common/policies.py
@@ -312,7 +312,7 @@ def predict(
         deterministic: bool = False,
     ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
         """
-        Get the policy action and state from an observation (and optional hidden state).
+        Get the policy action from an observation (and optional hidden state).
         Includes sugar-coating to handle different observations (e.g. normalizing images).
 
         :param observation: the input observation
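End state of the series: `predict()` has the signature `predict(observation, state=None, episode_start=None, deterministic=False)` and returns `(actions, states)`, with `state` typed as `Optional[Tuple[np.ndarray, ...]]` so that pytype accepts echoing the hidden state back for recurrent policies. A rollout against the final API might look as follows; `RecurrentPPO` lives in SB3 contrib (where the changelog says the experimental recurrent support resides) and is used here purely as an illustration:

    import numpy as np
    from sb3_contrib import RecurrentPPO  # recurrent policies are in SB3 contrib

    model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1").learn(5_000)
    vec_env = model.get_env()
    obs = vec_env.reset()
    states = None  # becomes a Tuple[np.ndarray, ...] once the policy has been stepped
    episode_starts = np.ones((vec_env.num_envs,), dtype=bool)
    for _ in range(500):
        actions, states = model.predict(
            obs, state=states, episode_start=episode_starts, deterministic=True
        )
        obs, rewards, dones, infos = vec_env.step(actions)
        episode_starts = dones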