Fix evaluation script for recurrent policies #678

Merged · 9 commits · Nov 30, 2021
5 changes: 3 additions & 2 deletions docs/misc/changelog.rst
@@ -4,11 +4,12 @@ Changelog
==========


Release 1.3.1a2 (WIP)
Release 1.3.1a3 (WIP)
---------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Renamed ``mask`` argument of the ``predict()`` method to ``episode_start`` (used with RNN policies only)

New Features:
^^^^^^^^^^^^^
@@ -20,7 +21,7 @@ Bug Fixes:
^^^^^^^^^^
- Fixed a bug where ``set_env()`` with ``VecNormalize`` would result in an error with off-policy algorithms (thanks @cleversonahum)
- FPS calculation is now performed based on number of steps performed during last ``learn`` call, even when ``reset_num_timesteps`` is set to ``False`` (@kachayev)

- Fixed evaluation script for recurrent policies (experimental feature in SB3 contrib)

Deprecations:
^^^^^^^^^^^^^
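The breaking change above renames the ``mask`` argument of ``predict()`` to ``episode_start``. A minimal before/after sketch of a call site (PPO and CartPole are only placeholders here; any SB3 algorithm accepts the same keywords):

```python
import gym
import numpy as np

from stable_baselines3 import PPO

env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env)

obs = env.reset()
# Up to 1.3.1a2 the RNN reset flags were passed as ``mask``:
#   action, state = model.predict(obs, state=None, mask=np.ones((1,), dtype=bool))
# From 1.3.1a3 on, the same flags are passed as ``episode_start``:
episode_start = np.ones((1,), dtype=bool)  # True at the beginning of an episode
action, state = model.predict(obs, state=None, episode_start=episode_start)
```

For non-recurrent policies both arguments are simply ignored, so existing code only breaks if it passed ``mask`` explicitly.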
19 changes: 11 additions & 8 deletions stable_baselines3/common/base_class.py
@@ -542,21 +542,24 @@ def learn(
def predict(
self,
observation: np.ndarray,
state: Optional[np.ndarray] = None,
mask: Optional[np.ndarray] = None,
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Get the model's action(s) from an observation
Get the policy action from an observation (and optional hidden state).
Includes sugar-coating to handle different observations (e.g. normalizing images).

:param observation: the input observation
:param state: The last states (can be None, used in recurrent policies)
:param mask: The last masks (can be None, used in recurrent policies)
:param state: The last hidden states (can be None, used in recurrent policies)
:param episode_start: The last masks (can be None, used in recurrent policies)
this corresponds to the beginning of episodes,
where the hidden states of the RNN must be reset.
:param deterministic: Whether or not to return deterministic actions.
:return: the model's action and the next state
:return: the model's action and the next hidden state
(used in recurrent policies)
"""
return self.policy.predict(observation, state, mask, deterministic)
return self.policy.predict(observation, state, episode_start, deterministic)

def set_random_seed(self, seed: Optional[int] = None) -> None:
"""
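The updated docstring describes the intended calling pattern for recurrent policies: keep feeding the returned ``state`` back in, and set ``episode_start`` from the previous ``done`` flags so the hidden state is reset at episode boundaries. A sketch of that loop, assuming a recurrent model such as ``RecurrentPPO`` from SB3 contrib (the experimental feature mentioned in the changelog); the environment and model names are illustrative:

```python
import numpy as np

from sb3_contrib import RecurrentPPO  # assumed: recurrent policies live in SB3 contrib (experimental)
from stable_baselines3.common.env_util import make_vec_env

env = make_vec_env("CartPole-v1", n_envs=1)
model = RecurrentPPO("MlpLstmPolicy", env)

obs = env.reset()
state = None  # hidden state: None on the very first call
episode_starts = np.ones((env.num_envs,), dtype=bool)  # True => the RNN state must be reset
for _ in range(200):
    action, state = model.predict(obs, state=state, episode_start=episode_starts, deterministic=True)
    obs, rewards, dones, infos = env.step(action)
    # The next call needs to know which envs just finished, so their hidden state gets reset
    episode_starts = dones
```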
6 changes: 3 additions & 3 deletions stable_baselines3/common/evaluation.py
@@ -81,8 +81,9 @@ def evaluate_policy(
current_lengths = np.zeros(n_envs, dtype="int")
observations = env.reset()
states = None
episode_starts = np.ones((env.num_envs,), dtype=bool)
while (episode_counts < episode_count_targets).any():
actions, states = model.predict(observations, state=states, deterministic=deterministic)
actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=deterministic)
observations, rewards, dones, infos = env.step(actions)
current_rewards += rewards
current_lengths += 1
@@ -93,6 +94,7 @@ def evaluate_policy(
reward = rewards[i]
done = dones[i]
info = infos[i]
episode_starts[i] = done

if callback is not None:
callback(locals(), globals())
@@ -116,8 +118,6 @@ def evaluate_policy(
episode_counts[i] += 1
current_rewards[i] = 0
current_lengths[i] = 0
if states is not None:
states[i] *= 0

if render:
env.render()
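This is the actual bug fix: instead of zeroing ``states[i]`` in place after a ``done`` (which assumed a particular hidden-state layout), ``evaluate_policy`` now keeps an ``episode_starts`` array updated from the ``dones`` flags and passes it to ``predict()``, so the policy itself decides how to reset its hidden state. Nothing changes for the caller; a usage sketch, reusing the hypothetical recurrent model and vectorized env from the previous example:

```python
from stable_baselines3.common.evaluation import evaluate_policy

# `model` and `env` are assumed to be the RecurrentPPO model and VecEnv built above;
# evaluate_policy now handles the hidden-state resets internally.
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10, deterministic=True)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")
```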
20 changes: 11 additions & 9 deletions stable_baselines3/common/policies.py
@@ -307,26 +307,28 @@ def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
def predict(
self,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[np.ndarray] = None,
mask: Optional[np.ndarray] = None,
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Get the policy action and state from an observation (and optional state).
Get the policy action from an observation (and optional hidden state).
Includes sugar-coating to handle different observations (e.g. normalizing images).

:param observation: the input observation
:param state: The last states (can be None, used in recurrent policies)
:param mask: The last masks (can be None, used in recurrent policies)
:param state: The last hidden states (can be None, used in recurrent policies)
:param episode_start: The last masks (can be None, used in recurrent policies)
this corresponds to the beginning of episodes,
where the hidden states of the RNN must be reset.
:param deterministic: Whether or not to return deterministic actions.
:return: the model's action and the next state
:return: the model's action and the next hidden state
(used in recurrent policies)
"""
# TODO (GH/1): add support for RNN policies
# if state is None:
# state = self.initial_state
# if mask is None:
# mask = [False for _ in range(self.n_envs)]
# if episode_start is None:
# episode_start = [False for _ in range(self.n_envs)]
# Switch to eval mode (this affects batch norm / dropout)
self.set_training_mode(False)

10 changes: 5 additions & 5 deletions stable_baselines3/dqn/dqn.py
@@ -198,16 +198,16 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
def predict(
self,
observation: np.ndarray,
state: Optional[np.ndarray] = None,
mask: Optional[np.ndarray] = None,
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Overrides the base_class predict function to include epsilon-greedy exploration.

:param observation: the input observation
:param state: The last states (can be None, used in recurrent policies)
:param mask: The last masks (can be None, used in recurrent policies)
:param episode_start: The last masks (can be None, used in recurrent policies)
:param deterministic: Whether or not to return deterministic actions.
:return: the model's action and the next state
(used in recurrent policies)
@@ -222,7 +222,7 @@ def predict(
else:
action = np.array(self.action_space.sample())
else:
action, state = self.policy.predict(observation, state, mask, deterministic)
action, state = self.policy.predict(observation, state, episode_start, deterministic)
return action, state

def learn(
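DQN overrides ``predict()`` only to add epsilon-greedy exploration before delegating to the policy, so the rename is just threaded through to ``self.policy.predict()``. A simplified sketch of that pattern (not the exact SB3 implementation, which also handles vectorized and dict observations):

```python
import numpy as np


def epsilon_greedy_predict(policy, observation, action_space, exploration_rate,
                           state=None, episode_start=None, deterministic=False):
    """Sample a random action with probability ``exploration_rate`` when not acting
    deterministically; otherwise delegate to the (possibly recurrent) policy."""
    if not deterministic and np.random.rand() < exploration_rate:
        action = np.array(action_space.sample())
    else:
        action, state = policy.predict(observation, state, episode_start, deterministic)
    return action, state
```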
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
1.3.1a2
1.3.1a3