Added support for vector envs in evaluation (DLR-RM#447)
* added vector env support to evaluate_policy

* fixed linting and documentation

* updated changelog

* fixed code style issue

* added tests for vec env

* fixed formatting

* renamed observations

* added comments for vector evaluation

* fixed issues

* Cleanup + bump version

* Add comment

* Fix wrong count of episodes

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
3 people authored and leor-c committed Aug 26, 2021
1 parent 36592fe commit 2413daa
Showing 7 changed files with 116 additions and 61 deletions.
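For orientation, a minimal usage sketch of what this change enables (the environment id, number of envs and training budget are illustrative, not taken from the diff):

from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy

# Evaluate on a vectorized env holding 4 copies of the task
eval_env = make_vec_env("CartPole-v1", n_envs=4)
model = A2C("MlpPolicy", "CartPole-v1", seed=0).learn(10_000)

# The 10 episodes are split statically across the 4 sub-envs as [2, 2, 3, 3]
mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")
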
4 changes: 2 additions & 2 deletions docs/guide/callbacks.rst
@@ -184,7 +184,7 @@ EvalCallback
^^^^^^^^^^^^

Evaluate periodically the performance of an agent, using a separate test environment.
It will save the best model if ``best_model_save_path`` folder is specified and save the evaluations results in a numpy archive (`evaluations.npz`) if ``log_path`` folder is specified.
It will save the best model if ``best_model_save_path`` folder is specified and save the evaluations results in a numpy archive (``evaluations.npz``) if ``log_path`` folder is specified.


.. note::
@@ -222,7 +222,7 @@ CallbackList
^^^^^^^^^^^^

Class for chaining callbacks, they will be called sequentially.
Alternatively, you can pass directly a list of callbacks to the `learn()` method, it will be converted automatically to a ``CallbackList``.
Alternatively, you can pass directly a list of callbacks to the ``learn()`` method, it will be converted automatically to a ``CallbackList``.


.. code-block:: python
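A minimal sketch of the list-of-callbacks usage described above (save path, frequencies and environment are illustrative):

import gym

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback

checkpoint_callback = CheckpointCallback(save_freq=1000, save_path="./logs/")
eval_callback = EvalCallback(gym.make("Pendulum-v0"), eval_freq=500)

model = SAC("MlpPolicy", "Pendulum-v0")
# The list is converted automatically to a CallbackList
model.learn(5000, callback=[checkpoint_callback, eval_callback])
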
5 changes: 3 additions & 2 deletions docs/misc/changelog.rst
@@ -4,7 +4,7 @@ Changelog
==========


Release 1.1.0a9 (WIP)
Release 1.1.0a10 (WIP)
---------------------------

**Dict observation support, timeout handling and refactored HER**
@@ -45,6 +45,7 @@ New Features:
- Added support for image observation when using ``HER``
- Added ``replay_buffer_class`` and ``replay_buffer_kwargs`` arguments to off-policy algorithms
- Added ``kl_divergence`` helper for ``Distribution`` classes (@09tangriro)
- Added support for vector environments with ``num_envs > 1`` (@benblack769)
- Added ``wrapper_kwargs`` argument to ``make_vec_env`` (@amy12xx)

Bug Fixes:
@@ -696,4 +697,4 @@ And all the contributors:
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @JadenTravnik @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx @juancroldan
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx @juancroldan @benblack769
6 changes: 0 additions & 6 deletions stable_baselines3/common/callbacks.py
@@ -318,9 +318,6 @@ def __init__(
if not isinstance(eval_env, VecEnv):
eval_env = DummyVecEnv([lambda: eval_env])

if isinstance(eval_env, VecEnv):
assert eval_env.num_envs == 1, "You must pass only one environment for evaluation"

self.eval_env = eval_env
self.best_model_save_path = best_model_save_path
# Logs will be written in ``evaluations.npz``
@@ -355,9 +352,6 @@ def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any
:param globals_:
"""
info = locals_["info"]
# VecEnv: unpack
if not isinstance(info, dict):
info = info[0]

if locals_["done"]:
maybe_is_success = info.get("is_success")
107 changes: 62 additions & 45 deletions stable_baselines3/common/evaluation.py
@@ -5,7 +5,7 @@
import numpy as np

from stable_baselines3.common import base_class
from stable_baselines3.common.vec_env import VecEnv, VecMonitor, is_vecenv_wrapped
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped


def evaluate_policy(
@@ -21,7 +21,10 @@ def evaluate_policy(
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
"""
Runs policy for ``n_eval_episodes`` episodes and returns average reward.
This is made to work only with one env.
If a vector env is passed in, this divides the episodes to evaluate onto the
different elements of the vector env. This static division of work is done to
remove bias. See https://github.com/DLR-RM/stable-baselines3/issues/402 for more
details and discussion.
.. note::
If environment has not been wrapped with ``Monitor`` wrapper, reward and
@@ -32,8 +35,7 @@
wrapper before anything else.
:param model: The RL agent you want to evaluate.
:param env: The gym environment. In the case of a ``VecEnv``
this must contain only one environment.
:param env: The gym environment or ``VecEnv`` environment.
:param n_eval_episodes: Number of episode to evaluate the agent
:param deterministic: Whether to use deterministic or stochastic actions
:param render: Whether to render the environment or not
@@ -52,14 +54,12 @@
"""
is_monitor_wrapped = False
# Avoid circular import
from stable_baselines3.common.env_util import is_wrapped
from stable_baselines3.common.monitor import Monitor

if isinstance(env, VecEnv):
assert env.num_envs == 1, "You must pass only one environment when using this function"
is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]
else:
is_monitor_wrapped = is_wrapped(env, Monitor)
if not isinstance(env, VecEnv):
env = DummyVecEnv([lambda: env])

is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]

if not is_monitor_wrapped and warn:
warnings.warn(
@@ -69,41 +69,58 @@
UserWarning,
)

episode_rewards, episode_lengths = [], []
not_reseted = True
while len(episode_rewards) < n_eval_episodes:
# Number of loops here might differ from true episodes
# played, if underlying wrappers modify episode lengths.
# Avoid double reset, as VecEnv are reset automatically.
if not isinstance(env, VecEnv) or not_reseted:
obs = env.reset()
not_reseted = False
done, state = False, None
episode_reward = 0.0
episode_length = 0
while not done:
action, state = model.predict(obs, state=state, deterministic=deterministic)
obs, reward, done, info = env.step(action)
episode_reward += reward
if callback is not None:
callback(locals(), globals())
episode_length += 1
if render:
env.render()

if is_monitor_wrapped:
# Do not trust "done" with episode endings.
# Remove vecenv stacking (if any)
if isinstance(env, VecEnv):
info = info[0]
if "episode" in info.keys():
# Monitor wrapper includes "episode" key in info if environment
# has been wrapped with it. Use those rewards instead.
episode_rewards.append(info["episode"]["r"])
episode_lengths.append(info["episode"]["l"])
else:
episode_rewards.append(episode_reward)
episode_lengths.append(episode_length)
n_envs = env.num_envs
episode_rewards = []
episode_lengths = []

episode_counts = np.zeros(n_envs, dtype="int")
# Divides episodes among different sub environments in the vector as evenly as possible
episode_count_targets = np.array([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int")

current_rewards = np.zeros(n_envs)
current_lengths = np.zeros(n_envs, dtype="int")
observations = env.reset()
states = None
while (episode_counts < episode_count_targets).any():
actions, states = model.predict(observations, state=states, deterministic=deterministic)
observations, rewards, dones, infos = env.step(actions)
current_rewards += rewards
current_lengths += 1
for i in range(n_envs):
if episode_counts[i] < episode_count_targets[i]:

# unpack values so that the callback can access the local variables
reward = rewards[i]
done = dones[i]
info = infos[i]

if callback is not None:
callback(locals(), globals())

if dones[i]:
if is_monitor_wrapped:
# Atari wrapper can send a "done" signal when
# the agent loses a life, but it does not correspond
# to the true end of episode
if "episode" in info.keys():
# Do not trust "done" with episode endings.
# Monitor wrapper includes "episode" key in info if environment
# has been wrapped with it. Use those rewards instead.
episode_rewards.append(info["episode"]["r"])
episode_lengths.append(info["episode"]["l"])
# Only increment at the real end of an episode
episode_counts[i] += 1
else:
episode_rewards.append(current_rewards[i])
episode_lengths.append(current_lengths[i])
episode_counts[i] += 1
current_rewards[i] = 0
current_lengths[i] = 0
if states is not None:
states[i] *= 0

if render:
env.render()

mean_reward = np.mean(episode_rewards)
std_reward = np.std(episode_rewards)
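To make the static division above concrete, the per-env episode targets can be checked in isolation (the numbers are illustrative):

import numpy as np

# Same formula as in evaluate_policy: a fixed target per sub-env, so envs that
# happen to finish episodes faster cannot contribute extra (biased) episodes.
n_eval_episodes, n_envs = 6, 4
episode_count_targets = np.array([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int")
print(episode_count_targets)        # [1 1 2 2]
print(episode_count_targets.sum())  # 6
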
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
1.1.0a9
1.1.0a10
21 changes: 19 additions & 2 deletions tests/test_callbacks.py
@@ -15,7 +15,7 @@
StopTrainingOnRewardThreshold,
)
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import BitFlippingEnv
from stable_baselines3.common.envs import BitFlippingEnv, IdentityEnv
from stable_baselines3.common.vec_env import DummyVecEnv


@@ -102,10 +102,27 @@ def select_env(model_class) -> str:
return "Pendulum-v0"


def test_eval_callback_vec_env():
# tests that eval callback does not crash when given a vector
n_eval_envs = 3
train_env = IdentityEnv()
eval_env = DummyVecEnv([lambda: IdentityEnv()] * n_eval_envs)
model = A2C("MlpPolicy", train_env, seed=0)

eval_callback = EvalCallback(
eval_env,
eval_freq=100,
warn=False,
)
model.learn(300, callback=eval_callback)
assert eval_callback.last_mean_reward == 100.0


def test_eval_success_logging(tmp_path):
n_bits = 2
n_envs = 2
env = BitFlippingEnv(n_bits=n_bits)
eval_env = DummyVecEnv([lambda: BitFlippingEnv(n_bits=n_bits)])
eval_env = DummyVecEnv([lambda: BitFlippingEnv(n_bits=n_bits)] * n_envs)
eval_callback = EvalCallback(
eval_env,
eval_freq=250,
32 changes: 29 additions & 3 deletions tests/test_utils.py
@@ -197,11 +197,37 @@ def reset(self, **kwargs):
return self.last_obs


@pytest.mark.parametrize("n_envs", [1, 2, 5, 7])
def test_evaluate_vector_env(n_envs):
# Tests that the number of episodes evaluated is correct
n_eval_episodes = 6

env = make_vec_env("CartPole-v1", n_envs)
model = A2C("MlpPolicy", "CartPole-v1", seed=0)

class CountCallback:
def __init__(self):
self.count = 0

def __call__(self, locals_, globals_):
if locals_["done"]:
self.count += 1

count_callback = CountCallback()

evaluate_policy(model, env, n_eval_episodes, callback=count_callback)

assert count_callback.count == n_eval_episodes


@pytest.mark.parametrize("vec_env_class", [None, DummyVecEnv, SubprocVecEnv])
def test_evaluate_policy_monitors(vec_env_class):
# Make numpy warnings throw exception
np.seterr(all="raise")
# Test that results are correct with monitor environments.
# Also test VecEnvs
n_eval_episodes = 2
n_eval_episodes = 3
n_envs = 2
env_id = "CartPole-v0"
model = A2C("MlpPolicy", env_id, seed=0)

@@ -217,9 +243,9 @@ def make_eval_env(with_monitor, wrapper_class=gym.Wrapper):
env = wrapper_class(env)
else:
if with_monitor:
env = vec_env_class([lambda: wrapper_class(Monitor(gym.make(env_id)))])
env = vec_env_class([lambda: wrapper_class(Monitor(gym.make(env_id)))] * n_envs)
else:
env = vec_env_class([lambda: wrapper_class(gym.make(env_id))])
env = vec_env_class([lambda: wrapper_class(gym.make(env_id))] * n_envs)
return env

# Test that evaluation with VecEnvs works as expected
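Relatedly, the return_episode_rewards=True path (visible in the return annotation of evaluate_policy) now collects the per-episode lists across all sub-environments. A hedged sketch, assuming CartPole-v1 is available:

from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy

env = make_vec_env("CartPole-v1", n_envs=2)
model = A2C("MlpPolicy", "CartPole-v1", seed=0)

# One entry per evaluated episode, gathered from both sub-environments
rewards, lengths = evaluate_policy(model, env, n_eval_episodes=6, return_episode_rewards=True)
assert len(rewards) == len(lengths) == 6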
