Fix evaluation script for recurrent policies (#678)
* Fix evaluation script for RNN

* Add error message

* Revert "Add error message"

This reverts commit 8d69b6c.

* Fix for pytype

* Rename mask to `episode_start`

* Fix type hint

* Fix type hints

* Remove confusing part of sentence

Co-authored-by: Anssi <kaneran21@hotmail.com>
araffin and Miffyli authored Nov 30, 2021
1 parent 8e5ede7 commit 52c29dc
Showing 6 changed files with 34 additions and 28 deletions.
5 changes: 3 additions & 2 deletions docs/misc/changelog.rst
@@ -4,11 +4,12 @@ Changelog
==========


Release 1.3.1a2 (WIP)
Release 1.3.1a3 (WIP)
---------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Renamed ``mask`` argument of the ``predict()`` method to ``episode_start`` (used with RNN policies only)

New Features:
^^^^^^^^^^^^^
@@ -20,7 +21,7 @@ Bug Fixes:
^^^^^^^^^^
- Fixed a bug where ``set_env()`` with ``VecNormalize`` would result in an error with off-policy algorithms (thanks @cleversonahum)
- FPS calculation is now performed based on number of steps performed during last ``learn`` call, even when ``reset_num_timesteps`` is set to ``False`` (@kachayev)

- Fixed evaluation script for recurrent policies (experimental feature in SB3 contrib)

Deprecations:
^^^^^^^^^^^^^
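For reference, a minimal usage sketch of the renamed argument. The recurrent policy class (RecurrentPPO with the "MlpLstmPolicy" alias, imported from sb3_contrib) is an assumption here, since recurrent policies are an experimental feature in SB3 contrib and not part of this commit:

import numpy as np
from sb3_contrib import RecurrentPPO  # assumed import path (experimental SB3-contrib feature)

model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1").learn(5_000)
vec_env = model.get_env()

obs = vec_env.reset()
lstm_states = None
# `episode_start` replaces the old `mask` argument: True flags the beginning
# of an episode, where the RNN hidden states must be reset.
episode_starts = np.ones((vec_env.num_envs,), dtype=bool)
for _ in range(200):
    action, lstm_states = model.predict(
        obs, state=lstm_states, episode_start=episode_starts, deterministic=True
    )
    obs, rewards, dones, infos = vec_env.step(action)
    episode_starts = dones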
19 changes: 11 additions & 8 deletions stable_baselines3/common/base_class.py
@@ -542,21 +542,24 @@ def learn(
def predict(
self,
observation: np.ndarray,
state: Optional[np.ndarray] = None,
mask: Optional[np.ndarray] = None,
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Get the model's action(s) from an observation
Get the policy action from an observation (and optional hidden state).
Includes sugar-coating to handle different observations (e.g. normalizing images).
:param observation: the input observation
:param state: The last states (can be None, used in recurrent policies)
:param mask: The last masks (can be None, used in recurrent policies)
:param state: The last hidden states (can be None, used in recurrent policies)
:param episode_start: The last masks (can be None, used in recurrent policies)
this corresponds to the beginning of episodes,
where the hidden states of the RNN must be reset.
:param deterministic: Whether or not to return deterministic actions.
:return: the model's action and the next state
:return: the model's action and the next hidden state
(used in recurrent policies)
"""
return self.policy.predict(observation, state, mask, deterministic)
return self.policy.predict(observation, state, episode_start, deterministic)

def set_random_seed(self, seed: Optional[int] = None) -> None:
"""
6 changes: 3 additions & 3 deletions stable_baselines3/common/evaluation.py
@@ -81,8 +81,9 @@ def evaluate_policy(
current_lengths = np.zeros(n_envs, dtype="int")
observations = env.reset()
states = None
episode_starts = np.ones((env.num_envs,), dtype=bool)
while (episode_counts < episode_count_targets).any():
actions, states = model.predict(observations, state=states, deterministic=deterministic)
actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=deterministic)
observations, rewards, dones, infos = env.step(actions)
current_rewards += rewards
current_lengths += 1
@@ -93,6 +94,7 @@ def evaluate_policy(
reward = rewards[i]
done = dones[i]
info = infos[i]
episode_starts[i] = done

if callback is not None:
callback(locals(), globals())
@@ -116,8 +118,6 @@ def evaluate_policy(
episode_counts[i] += 1
current_rewards[i] = 0
current_lengths[i] = 0
if states is not None:
states[i] *= 0

if render:
env.render()
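In short, the evaluation loop no longer zeroes `states[i]` by hand; it keeps a per-env `episode_starts` array and lets `predict()` handle the hidden-state reset. A condensed sketch of that loop (`model` and `env` stand for an already-loaded policy and VecEnv, and `n_eval_steps` is an assumed step budget for the sketch, not part of the diff):

import numpy as np

# Condensed sketch of the fixed loop in evaluate_policy (not the full function)
observations = env.reset()
states = None
episode_starts = np.ones((env.num_envs,), dtype=bool)
for _ in range(n_eval_steps):
    actions, states = model.predict(
        observations, state=states, episode_start=episode_starts, deterministic=True
    )
    observations, rewards, dones, infos = env.step(actions)
    # an env that just finished starts a new episode on the next call,
    # so its flag tells predict() to reset that env's hidden state
    episode_starts = dones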
20 changes: 11 additions & 9 deletions stable_baselines3/common/policies.py
@@ -307,26 +307,28 @@ def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Te
def predict(
self,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[np.ndarray] = None,
mask: Optional[np.ndarray] = None,
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Get the policy action and state from an observation (and optional state).
Get the policy action from an observation (and optional hidden state).
Includes sugar-coating to handle different observations (e.g. normalizing images).
:param observation: the input observation
:param state: The last states (can be None, used in recurrent policies)
:param mask: The last masks (can be None, used in recurrent policies)
:param state: The last hidden states (can be None, used in recurrent policies)
:param episode_start: The last masks (can be None, used in recurrent policies)
this corresponds to the beginning of episodes,
where the hidden states of the RNN must be reset.
:param deterministic: Whether or not to return deterministic actions.
:return: the model's action and the next state
:return: the model's action and the next hidden state
(used in recurrent policies)
"""
# TODO (GH/1): add support for RNN policies
# if state is None:
# state = self.initial_state
# if mask is None:
# mask = [False for _ in range(self.n_envs)]
# if episode_start is None:
# episode_start = [False for _ in range(self.n_envs)]
# Switch to eval mode (this affects batch norm / dropout)
self.set_training_mode(False)

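The commented-out TODO above only sets defaults; the actual hidden-state reset lives in the recurrent policies themselves (SB3 contrib). A purely illustrative helper showing the usual pattern, not code from this commit:

import numpy as np

def reset_hidden_states(states: np.ndarray, episode_start: np.ndarray, initial_state: np.ndarray) -> np.ndarray:
    # Illustrative only: where `episode_start` is True, replace that env's
    # hidden state with the initial (typically zero) state.
    # Shapes assumed: states (n_layers, n_envs, hidden_dim), episode_start (n_envs,).
    mask = episode_start[np.newaxis, :, np.newaxis]
    return np.where(mask, initial_state, states)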
10 changes: 5 additions & 5 deletions stable_baselines3/dqn/dqn.py
@@ -198,16 +198,16 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
def predict(
self,
observation: np.ndarray,
state: Optional[np.ndarray] = None,
mask: Optional[np.ndarray] = None,
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Overrides the base_class predict function to include epsilon-greedy exploration.
:param observation: the input observation
:param state: The last states (can be None, used in recurrent policies)
:param mask: The last masks (can be None, used in recurrent policies)
:param episode_start: The last masks (can be None, used in recurrent policies)
:param deterministic: Whether or not to return deterministic actions.
:return: the model's action and the next state
(used in recurrent policies)
@@ -222,7 +222,7 @@ def predict(
else:
action = np.array(self.action_space.sample())
else:
action, state = self.policy.predict(observation, state, mask, deterministic)
action, state = self.policy.predict(observation, state, episode_start, deterministic)
return action, state

def learn(
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
1.3.1a2
1.3.1a3
