From 52c29dc497fa2eb235d0476b067bed8ac488fe64 Mon Sep 17 00:00:00 2001
From: Antonin RAFFIN
Date: Tue, 30 Nov 2021 13:49:06 +0100
Subject: [PATCH] Fix evaluation script for recurrent policies (#678)

* Fix evaluation script for RNN

* Add error message

* Revert "Add error message"

This reverts commit 8d69b6cf4de2cd13aecfb425bd3145fad6a6c49a.

* Fix for pytype

* Rename mask to `episode_start`

* Fix type hint

* Fix type hints

* Remove confusing part of sentence

Co-authored-by: Anssi
---
 docs/misc/changelog.rst                |  5 +++--
 stable_baselines3/common/base_class.py | 19 +++++++++++--------
 stable_baselines3/common/evaluation.py |  6 +++---
 stable_baselines3/common/policies.py   | 20 +++++++++++---------
 stable_baselines3/dqn/dqn.py           | 10 +++++-----
 stable_baselines3/version.txt          |  2 +-
 6 files changed, 34 insertions(+), 28 deletions(-)

diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 0a0668d0e..441a04ed6 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -4,11 +4,12 @@ Changelog
 ==========
 
-Release 1.3.1a2 (WIP)
+Release 1.3.1a3 (WIP)
 ---------------------------
 
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
+- Renamed ``mask`` argument of the ``predict()`` method to ``episode_start`` (used with RNN policies only)
 
 New Features:
 ^^^^^^^^^^^^^
 
@@ -20,7 +21,7 @@ Bug Fixes:
 ^^^^^^^^^^
 - Fixed a bug where ``set_env()`` with ``VecNormalize`` would result in an error with off-policy algorithms (thanks @cleversonahum)
 - FPS calculation is now performed based on number of steps performed during last ``learn`` call, even when ``reset_num_timesteps`` is set to ``False`` (@kachayev)
-
+- Fixed evaluation script for recurrent policies (experimental feature in SB3 contrib)
 
 Deprecations:
 ^^^^^^^^^^^^^
diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py
index 4bad2c2ca..757bafee3 100644
--- a/stable_baselines3/common/base_class.py
+++ b/stable_baselines3/common/base_class.py
@@ -542,21 +542,24 @@ def learn(
     def predict(
         self,
         observation: np.ndarray,
-        state: Optional[np.ndarray] = None,
-        mask: Optional[np.ndarray] = None,
+        state: Optional[Tuple[np.ndarray, ...]] = None,
+        episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
-    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
         """
-        Get the model's action(s) from an observation
+        Get the policy action from an observation (and optional hidden state).
+        Includes sugar-coating to handle different observations (e.g. normalizing images).
 
         :param observation: the input observation
-        :param state: The last states (can be None, used in recurrent policies)
-        :param mask: The last masks (can be None, used in recurrent policies)
+        :param state: The last hidden states (can be None, used in recurrent policies)
+        :param episode_start: The last masks (can be None, used in recurrent policies)
+            this corresponds to the beginning of episodes,
+            where the hidden states of the RNN must be reset.
         :param deterministic: Whether or not to return deterministic actions.
-        :return: the model's action and the next state
+        :return: the model's action and the next hidden state
             (used in recurrent policies)
         """
-        return self.policy.predict(observation, state, mask, deterministic)
+        return self.policy.predict(observation, state, episode_start, deterministic)
 
     def set_random_seed(self, seed: Optional[int] = None) -> None:
         """
diff --git a/stable_baselines3/common/evaluation.py b/stable_baselines3/common/evaluation.py
index ae05f4919..e3f14d3f8 100644
--- a/stable_baselines3/common/evaluation.py
+++ b/stable_baselines3/common/evaluation.py
@@ -81,8 +81,9 @@ def evaluate_policy(
     current_lengths = np.zeros(n_envs, dtype="int")
     observations = env.reset()
     states = None
+    episode_starts = np.ones((env.num_envs,), dtype=bool)
     while (episode_counts < episode_count_targets).any():
-        actions, states = model.predict(observations, state=states, deterministic=deterministic)
+        actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=deterministic)
         observations, rewards, dones, infos = env.step(actions)
         current_rewards += rewards
         current_lengths += 1
@@ -93,6 +94,7 @@ def evaluate_policy(
                 reward = rewards[i]
                 done = dones[i]
                 info = infos[i]
+                episode_starts[i] = done
 
                 if callback is not None:
                     callback(locals(), globals())
@@ -116,8 +118,6 @@ def evaluate_policy(
                         episode_counts[i] += 1
                     current_rewards[i] = 0
                     current_lengths[i] = 0
-                    if states is not None:
-                        states[i] *= 0
 
         if render:
             env.render()
diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py
index 9b4559274..33918b784 100644
--- a/stable_baselines3/common/policies.py
+++ b/stable_baselines3/common/policies.py
@@ -307,26 +307,28 @@ def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Te
     def predict(
         self,
         observation: Union[np.ndarray, Dict[str, np.ndarray]],
-        state: Optional[np.ndarray] = None,
-        mask: Optional[np.ndarray] = None,
+        state: Optional[Tuple[np.ndarray, ...]] = None,
+        episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
-    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
         """
-        Get the policy action and state from an observation (and optional state).
+        Get the policy action from an observation (and optional hidden state).
         Includes sugar-coating to handle different observations (e.g. normalizing images).
 
         :param observation: the input observation
-        :param state: The last states (can be None, used in recurrent policies)
-        :param mask: The last masks (can be None, used in recurrent policies)
+        :param state: The last hidden states (can be None, used in recurrent policies)
+        :param episode_start: The last masks (can be None, used in recurrent policies)
+            this corresponds to the beginning of episodes,
+            where the hidden states of the RNN must be reset.
         :param deterministic: Whether or not to return deterministic actions.
-        :return: the model's action and the next state
+        :return: the model's action and the next hidden state
             (used in recurrent policies)
         """
         # TODO (GH/1): add support for RNN policies
         # if state is None:
         #     state = self.initial_state
-        # if mask is None:
-        #     mask = [False for _ in range(self.n_envs)]
+        # if episode_start is None:
+        #     episode_start = [False for _ in range(self.n_envs)]
 
         # Switch to eval mode (this affects batch norm / dropout)
         self.set_training_mode(False)
diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py
index d3b1eb583..11e7ac711 100644
--- a/stable_baselines3/dqn/dqn.py
+++ b/stable_baselines3/dqn/dqn.py
@@ -198,16 +198,16 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
     def predict(
         self,
         observation: np.ndarray,
-        state: Optional[np.ndarray] = None,
-        mask: Optional[np.ndarray] = None,
+        state: Optional[Tuple[np.ndarray, ...]] = None,
+        episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
-    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
         """
         Overrides the base_class predict function to include epsilon-greedy exploration.
 
         :param observation: the input observation
         :param state: The last states (can be None, used in recurrent policies)
-        :param mask: The last masks (can be None, used in recurrent policies)
+        :param episode_start: The last masks (can be None, used in recurrent policies)
         :param deterministic: Whether or not to return deterministic actions.
         :return: the model's action and the next state
             (used in recurrent policies)
@@ -222,7 +222,7 @@ def predict(
             else:
                 action = np.array(self.action_space.sample())
         else:
-            action, state = self.policy.predict(observation, state, mask, deterministic)
+            action, state = self.policy.predict(observation, state, episode_start, deterministic)
         return action, state
 
     def learn(
diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt
index c5813fb1b..896c1f343 100644
--- a/stable_baselines3/version.txt
+++ b/stable_baselines3/version.txt
@@ -1 +1 @@
-1.3.1a2
+1.3.1a3
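
For reference, a minimal usage sketch (not part of the patch) of how the renamed ``episode_start`` argument is meant to be fed back into ``predict()`` when rolling out a policy manually on a ``VecEnv``, mirroring the patched ``evaluate_policy`` loop above. The choice of ``PPO`` on ``CartPole-v1`` is illustrative only: a non-recurrent policy simply ignores ``state`` and ``episode_start``, while recurrent policies (the experimental SB3 contrib feature this patch targets) use them to reset their hidden states at episode boundaries.

import numpy as np

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

vec_env = make_vec_env("CartPole-v1", n_envs=4)
model = PPO("MlpPolicy", vec_env).learn(1_000)

obs = vec_env.reset()
states = None
# Every env starts a new episode at the first step, so hidden states (if any) must be reset
episode_starts = np.ones((vec_env.num_envs,), dtype=bool)

for _ in range(100):
    # `mask` is now called `episode_start`
    actions, states = model.predict(obs, state=states, episode_start=episode_starts, deterministic=True)
    obs, rewards, dones, infos = vec_env.step(actions)
    # Envs that just finished begin a new episode at the next call,
    # which tells a recurrent policy to reset its hidden state for those envs
    episode_starts = dones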