
DQN fixes #39

Merged
merged 9 commits on Oct 1, 2018
2 changes: 1 addition & 1 deletion docs/guide/custom_policy.rst
@@ -3,7 +3,7 @@
Custom Policy Network
---------------------

Stable baselines provides default policy networks for images (CNNPolicies)
Stable baselines provides default policy networks (see :ref:`Policies <policies>`) for images (CNNPolicies)
and other types of input features (MlpPolicies).
However, you can also easily define a custom architecture for the policy (or value) network:

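
For context on the custom architecture mentioned in custom_policy.rst, here is a minimal sketch of such a policy. It is not part of this PR; the layers/feature_extraction keyword arguments are assumed from the 2.x FeedForwardPolicy API:

from stable_baselines.common.policies import FeedForwardPolicy

# Sketch: an MLP policy with two hidden layers of 128 units instead of the defaults.
class CustomMlpPolicy(FeedForwardPolicy):
    def __init__(self, *args, **kwargs):
        super(CustomMlpPolicy, self).__init__(*args, **kwargs,
                                              layers=[128, 128],
                                              feature_extraction="mlp")

The class can then be passed in place of MlpPolicy when constructing a model.
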
2 changes: 2 additions & 0 deletions docs/misc/changelog.rst
@@ -11,6 +11,8 @@ Pre Release 2.0.1.a0 (WIP)
**logging and bug fixes**

- added patch fix for equal function using `gym.spaces.MultiDiscrete` and `gym.spaces.MultiBinary`
- fixes for DQN action_probability
- re-added double DQN + refactored DQN policies **breaking changes**
- replaced `async` with `async_eigen_decomp` in ACKTR/KFAC for python 3.7 compat


3 changes: 1 addition & 2 deletions docs/modules/acer.rst
@@ -42,8 +42,7 @@ Example

import gym

from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy, \
CnnPolicy, CnnLstmPolicy, CnnLnLstmPolicy
from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines import ACER

3 changes: 1 addition & 2 deletions docs/modules/acktr.rst
@@ -43,8 +43,7 @@ Example

import gym

from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy, \
CnnPolicy, CnnLstmPolicy, CnnLnLstmPolicy
from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines import ACKTR

13 changes: 12 additions & 1 deletion docs/modules/ddpg.rst
@@ -13,6 +13,17 @@ DDPG
The DDPG model does not support ``stable_baselines.common.policies`` because it uses Q-value instead
of value estimation; as a result, it must use its own policy models (see :ref:`ddpg_policies`).


.. rubric:: Available Policies

.. autosummary::
:nosignatures:

MlpPolicy
LnMlpPolicy
CnnPolicy
LnCnnPolicy

Notes
-----

@@ -47,7 +58,7 @@ Example
import gym
import numpy as np

from stable_baselines.ddpg.policies import MlpPolicy, CnnPolicy
from stable_baselines.ddpg.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.ddpg.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise, AdaptiveParamNoiseSpec
from stable_baselines import DDPG
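
The example above now imports only the policy it actually uses; since DDPG cannot use stable_baselines.common.policies, its policies come from stable_baselines.ddpg.policies. A minimal, self-contained sketch of the resulting setup (not code from this PR; the NormalActionNoise arguments are assumed from the 2.x API):

import gym
import numpy as np

from stable_baselines.ddpg.policies import MlpPolicy  # note: not stable_baselines.common.policies
from stable_baselines.ddpg.noise import NormalActionNoise
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import DDPG

env = DummyVecEnv([lambda: gym.make('Pendulum-v0')])
n_actions = env.action_space.shape[-1]
# Gaussian exploration noise added to the continuous actions
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

model = DDPG(MlpPolicy, env, action_noise=action_noise, verbose=1)
model.learn(total_timesteps=10000)
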
12 changes: 11 additions & 1 deletion docs/modules/dqn.rst
@@ -14,6 +14,16 @@ and its extensions (Double-DQN, Dueling-DQN, Prioritized Experience Replay).
The DQN model does not support ``stable_baselines.common.policies``;
as a result, it must use its own policy models (see :ref:`deepq_policies`).

.. rubric:: Available Policies

.. autosummary::
:nosignatures:

MlpPolicy
LnMlpPolicy
CnnPolicy
LnCnnPolicy

Notes
-----

@@ -46,7 +56,7 @@ Example
import gym

from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.deepq.policies import MlpPolicy, CnnPolicy
from stable_baselines.deepq.policies import MlpPolicy
from stable_baselines import DQN

env = gym.make('CartPole-v1')
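
The example above is truncated after the environment creation; a minimal sketch of how it continues (not part of this PR, hyperparameters are illustrative only):

import gym

from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.deepq.policies import MlpPolicy
from stable_baselines import DQN

env = gym.make('CartPole-v1')
env = DummyVecEnv([lambda: env])

model = DQN(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=25000)

obs = env.reset()
# predict() is now greedy by default (deterministic=True), see the dqn.py change below
action, _states = model.predict(obs)
obs, rewards, dones, info = env.step(action)
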
19 changes: 19 additions & 0 deletions docs/modules/policies.rst
@@ -5,6 +5,25 @@
Policy Networks
===============

Stable-baselines provides a set of default policies that can be used with most action spaces.
If you need more control over the policy architecture, you can also create a custom policy (see :ref:`custom_policy`).

.. note::

CnnPolicies are for images only. MlpPolicies are made for other types of features (e.g. robot joints).

.. rubric:: Available Policies

.. autosummary::
:nosignatures:

MlpPolicy
MlpLstmPolicy
MlpLnLstmPolicy
CnnPolicy
CnnLstmPolicy
CnnLnLstmPolicy


Base Classes
------------
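
To make the CnnPolicy/MlpPolicy note added in this file concrete, a short sketch (not from this PR; A2C and DummyVecEnv are just convenient choices here):

import gym

from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import A2C

# MlpPolicy: low-dimensional feature observations (e.g. CartPole's 4-dimensional state)
env = DummyVecEnv([lambda: gym.make('CartPole-v1')])
model = A2C(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=10000)

# CnnPolicy would be chosen instead when the observations are images,
# e.g. an Atari environment wrapped so the input is a stack of frames.
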
3 changes: 1 addition & 2 deletions docs/modules/ppo1.rst
@@ -52,8 +52,7 @@ Example

import gym

from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy, \
CnnPolicy, CnnLstmPolicy, CnnLnLstmPolicy
from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO1

3 changes: 1 addition & 2 deletions docs/modules/trpo.rst
@@ -43,8 +43,7 @@ Example

import gym

from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy, \
CnnPolicy, CnnLstmPolicy, CnnLnLstmPolicy
from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import TRPO

34 changes: 16 additions & 18 deletions stable_baselines/deepq/build_graph.py
@@ -140,7 +140,7 @@ def build_act(q_func, ob_space, ac_space, stochastic_ph, update_eps_ph, sess):

policy = q_func(sess, ob_space, ac_space, 1, 1, None)
obs_phs = (policy.obs_ph, policy.processed_x)
deterministic_actions = policy.proba_distribution.mode()
deterministic_actions = tf.argmax(policy.q_values, axis=1)

batch_size = tf.shape(policy.obs_ph)[0]
n_actions = ac_space.nvec if isinstance(ac_space, MultiDiscrete) else ac_space.n
@@ -235,8 +235,8 @@ def perturb_vars(original_scope, perturbed_scope):
adaptive_policy = q_func(sess, ob_space, ac_space, 1, 1, None, obs_phs=obs_phs)
perturb_for_adaption = perturb_vars(original_scope="model", perturbed_scope="adaptive_model/model")
kl_loss = tf.reduce_sum(
tf.nn.softmax(policy.value_fn) *
(tf.log(tf.nn.softmax(policy.value_fn)) - tf.log(tf.nn.softmax(adaptive_policy.value_fn))),
tf.nn.softmax(policy.q_values) *
(tf.log(tf.nn.softmax(policy.q_values)) - tf.log(tf.nn.softmax(adaptive_policy.q_values))),
axis=-1)
mean_kl = tf.reduce_mean(kl_loss)

@@ -259,8 +259,8 @@ def update_scale():
lambda: param_noise_threshold))

# Put everything together.
perturbed_deterministic_actions = tf.argmax(perturbable_policy.value_fn, axis=1)
deterministic_actions = tf.argmax(policy.value_fn, axis=1)
perturbed_deterministic_actions = tf.argmax(perturbable_policy.q_values, axis=1)
deterministic_actions = tf.argmax(policy.q_values, axis=1)
batch_size = tf.shape(policy.obs_ph)[0]
n_actions = ac_space.nvec if isinstance(ac_space, MultiDiscrete) else ac_space.n
random_actions = tf.random_uniform(tf.stack([batch_size]), minval=0, maxval=n_actions, dtype=tf.int64)
@@ -349,7 +349,7 @@ def build_train(q_func, ob_space, ac_space, optimizer, sess, grad_norm_clipping=
optimize the error in Bellman's equation. See the top of the file for details.
update_target: (function) copy the parameters from optimized Q function to the target Q function.
See the top of the file for details.
debug: ({str: function}) a bunch of functions to print debug data like q_values.
step_model: (DQNPolicy) Policy for evaluation
"""
n_actions = ac_space.nvec if isinstance(ac_space, MultiDiscrete) else ac_space.n
with tf.variable_scope("input", reuse=reuse):
@@ -364,23 +364,23 @@ def build_train(q_func, ob_space, ac_space, optimizer, sess, grad_norm_clipping=
act_f, obs_phs = build_act(q_func, ob_space, ac_space, stochastic_ph, update_eps_ph, sess)

# q network evaluation
with tf.variable_scope("eval_q_func", reuse=True, custom_getter=tf_util.outer_scope_getter("eval_q_func")):
eval_policy = q_func(sess, ob_space, ac_space, 1, 1, None, reuse=True, obs_phs=obs_phs)
with tf.variable_scope("step_model", reuse=True, custom_getter=tf_util.outer_scope_getter("step_model")):
step_model = q_func(sess, ob_space, ac_space, 1, 1, None, reuse=True, obs_phs=obs_phs)
q_func_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=tf.get_variable_scope().name + "/model")
# target q network evalution
# target q network evaluation

with tf.variable_scope("target_q_func", reuse=False):
target_policy = q_func(sess, ob_space, ac_space, 1, 1, None, reuse=False)
target_q_func_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
scope=tf.get_variable_scope().name + "/target_q_func")

# compute estimate of best possible value starting from state at t + 1
double_value_fn = None
double_q_values = None
double_obs_ph = target_policy.obs_ph
if double_q:
with tf.variable_scope("double_q", reuse=True, custom_getter=tf_util.outer_scope_getter("double_q")):
double_policy = q_func(sess, ob_space, ac_space, 1, 1, None, reuse=True)
double_value_fn = double_policy.value_fn
double_q_values = double_policy.q_values
double_obs_ph = double_policy.obs_ph

with tf.variable_scope("loss", reuse=reuse):
Expand All @@ -391,14 +391,14 @@ def build_train(q_func, ob_space, ac_space, optimizer, sess, grad_norm_clipping=
importance_weights_ph = tf.placeholder(tf.float32, [None], name="weight")

# q scores for actions which we know were selected in the given state.
q_t_selected = tf.reduce_sum(eval_policy.value_fn * tf.one_hot(act_t_ph, n_actions), 1)
q_t_selected = tf.reduce_sum(step_model.q_values * tf.one_hot(act_t_ph, n_actions), axis=1)

# compute estimate of best possible value starting from state at t + 1
if double_q:
q_tp1_best_using_online_net = tf.argmax(double_value_fn, 1)
q_tp1_best = tf.reduce_sum(target_policy.value_fn * tf.one_hot(q_tp1_best_using_online_net, n_actions), 1)
q_tp1_best_using_online_net = tf.argmax(double_q_values, axis=1)
q_tp1_best = tf.reduce_sum(target_policy.q_values * tf.one_hot(q_tp1_best_using_online_net, n_actions), axis=1)
else:
q_tp1_best = tf.reduce_max(target_policy.value_fn, 1)
q_tp1_best = tf.reduce_max(target_policy.q_values, axis=1)
q_tp1_best_masked = (1.0 - done_mask_ph) * q_tp1_best

# compute RHS of bellman equation
@@ -457,6 +457,4 @@ def build_train(q_func, ob_space, ac_space, optimizer, sess, grad_norm_clipping=
)
update_target = tf_util.function([], [], updates=[update_target_expr])

q_values = tf_util.function([obs_phs[0]], eval_policy.value_fn)

return act_f, train, update_target, {'q_values': q_values}
return act_f, train, update_target, step_model
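
For reference, the double-DQN target built by the TensorFlow code above can be written out in plain numpy. This is only a sketch with hypothetical array names, not code from this PR:

import numpy as np

def double_dqn_target(q_tp1_online, q_tp1_target, rewards, dones, gamma=0.99):
    # q_tp1_online / q_tp1_target: Q-values at s_{t+1} from the online and target
    # networks, each of shape (batch_size, n_actions)
    # action selection with the online network ...
    best_actions = np.argmax(q_tp1_online, axis=1)
    # ... but action evaluation with the target network
    q_tp1_best = q_tp1_target[np.arange(len(best_actions)), best_actions]
    # no bootstrapping through terminal transitions
    q_tp1_best_masked = (1.0 - dones) * q_tp1_best
    # RHS of the Bellman equation
    return rewards + gamma * q_tp1_best_masked
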
36 changes: 23 additions & 13 deletions stable_baselines/deepq/dqn.py
@@ -1,3 +1,5 @@
from functools import partial

import tensorflow as tf
import numpy as np
import gym
@@ -77,8 +79,10 @@ def __init__(self, policy, env, gamma=0.99, learning_rate=5e-4, buffer_size=5000
self.graph = None
self.sess = None
self._train_step = None
self.step_model = None
self.update_target = None
self.act = None
self.proba_step = None
self.replay_buffer = None
self.beta_schedule = None
self.exploration = None
@@ -91,10 +95,16 @@ def __init__(self, policy, env, gamma=0.99, learning_rate=5e-4, buffer_size=5000

def setup_model(self):
with SetVerbosity(self.verbose):

assert not isinstance(self.action_space, gym.spaces.Box), \
"Error: DQN cannot output a gym.spaces.Box action space."
assert issubclass(self.policy, DQNPolicy), "Error: the input policy for the DQN model must be " \

# If the policy is wrapped in functools.partial (e.g. to disable dueling)
# unwrap it to check the class type
if isinstance(self.policy, partial):
test_policy = self.policy.func
else:
test_policy = self.policy
assert issubclass(test_policy, DQNPolicy), "Error: the input policy for the DQN model must be " \
"an instance of DQNPolicy."

self.graph = tf.Graph()
@@ -103,7 +113,7 @@ def setup_model(self):

optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)

self.act, self._train_step, self.update_target, _ = deepq.build_train(
self.act, self._train_step, self.update_target, self.step_model = deepq.build_train(
q_func=self.policy,
ob_space=self.observation_space,
ac_space=self.action_space,
@@ -113,7 +123,7 @@ def setup_model(self):
param_noise=self.param_noise,
sess=self.sess
)

self.proba_step = self.step_model.proba_step
self.params = find_trainable_variables("deepq")

# Initialize the parameters and copy them to the target network.
@@ -239,13 +249,13 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_

return self

def predict(self, observation, state=None, mask=None, deterministic=False):
def predict(self, observation, state=None, mask=None, deterministic=True):
observation = np.array(observation)
vectorized_env = self._is_vectorized_observation(observation, self.observation_space)

observation = observation.reshape((-1,) + self.observation_space.shape)
with self.sess.as_default():
actions = self.act(observation, stochastic=not deterministic)
actions, _, _ = self.step_model.step(observation, deterministic=deterministic)

if not vectorized_env:
actions = actions[0]
@@ -257,14 +267,14 @@ def action_probability(self, observation, state=None, mask=None):
vectorized_env = self._is_vectorized_observation(observation, self.observation_space)

observation = observation.reshape((-1,) + self.observation_space.shape)
actions_proba = self.proba_step(observation, state, mask)

if not vectorized_env:
if state is not None:
raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
actions_proba = actions_proba[0]

# Get the tensor just before the softmax function in the TensorFlow graph,
# then execute the graph from the input observation to this tensor.
tensor = self.graph.get_tensor_by_name('deepq/q_func/fully_connected_2/BiasAdd:0')
if vectorized_env:
return self._softmax(self.sess.run(tensor, feed_dict={'deepq/observation:0': observation}))
else:
return self._softmax(self.sess.run(tensor, feed_dict={'deepq/observation:0': observation}))[0]
return actions_proba

def save(self, save_path):
# params
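
A short usage sketch of the two public methods touched above (not part of this PR; the environment and timestep count are illustrative):

import gym
import numpy as np

from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.deepq.policies import MlpPolicy
from stable_baselines import DQN

env = DummyVecEnv([lambda: gym.make('CartPole-v1')])
model = DQN(MlpPolicy, env, verbose=0)
model.learn(total_timesteps=1000)

obs = env.reset()
# predict() is now greedy by default; pass deterministic=False for stochastic actions
action, _ = model.predict(obs)

# action_probability() now goes through the policy's proba_step
# instead of reading a hard-coded tensor out of the graph
proba = model.action_probability(obs)
assert np.allclose(proba.sum(axis=-1), 1.0)
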
7 changes: 6 additions & 1 deletion stable_baselines/deepq/experiments/enjoy_mountaincar.py
@@ -1,6 +1,7 @@
import argparse

import gym
import numpy as np

from stable_baselines.deepq import DQN

@@ -20,7 +21,11 @@ def main(args):
while not done:
if not args.no_render:
env.render()
action, _ = model.predict(obs)
# Epsilon-greedy
if np.random.random() < 0.02:
action = env.action_space.sample()
else:
action, _ = model.predict(obs, deterministic=True)
obs, rew, done, _ = env.step(action)
episode_rew += rew
print("Episode reward", episode_rew)
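
The same epsilon-greedy evaluation step, factored into a small helper for reuse (a sketch only; epsilon_greedy_action is not a function in this repository):

import numpy as np

def epsilon_greedy_action(model, env, obs, epsilon=0.02):
    # with probability epsilon take a random action, otherwise the greedy one
    if np.random.random() < epsilon:
        return env.action_space.sample()
    action, _ = model.predict(obs, deterministic=True)
    return action
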
6 changes: 4 additions & 2 deletions stable_baselines/deepq/experiments/run_atari.py
@@ -1,4 +1,5 @@
import argparse
from functools import partial

from stable_baselines import bench, logger
from stable_baselines.common import set_global_seeds
@@ -14,8 +15,8 @@ def main():
parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
parser.add_argument('--seed', help='RNG seed', type=int, default=0)
parser.add_argument('--prioritized', type=int, default=1)
parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
parser.add_argument('--dueling', type=int, default=1)
parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
parser.add_argument('--num-timesteps', type=int, default=int(10e6))
parser.add_argument('--checkpoint-freq', type=int, default=10000)
parser.add_argument('--checkpoint-path', type=str, default=None)
@@ -26,10 +27,11 @@ def main():
env = make_atari(args.env)
env = bench.Monitor(env, logger.get_dir())
env = wrap_atari_dqn(env)
policy = partial(CnnPolicy, dueling=args.dueling == 1)

model = DQN(
env=env,
policy=CnnPolicy,
policy=policy,
learning_rate=1e-4,
buffer_size=10000,
exploration_fraction=0.1,
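
The partial() trick above is the general way to pass policy keyword arguments while still giving DQN a policy class; setup_model() (see dqn.py above) unwraps the partial before checking that it subclasses DQNPolicy. A minimal sketch outside this script (not part of this PR; CartPole and MlpPolicy are just illustrative):

import gym
from functools import partial

from stable_baselines.deepq.policies import MlpPolicy
from stable_baselines import DQN

# disable the dueling architecture via a policy keyword argument
no_dueling_policy = partial(MlpPolicy, dueling=False)

model = DQN(no_dueling_policy, gym.make('CartPole-v1'), verbose=1)
model.learn(total_timesteps=1000)
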