Hindsight Experience Replay (HER) - Reloaded #273

Merged: 51 commits (Jun 4, 2019)
Changes from 5 commits
Commits (51)
a615b2a
Add bit flipping env
araffin Apr 11, 2019
5bfa61c
HER reloaded (WIP)
araffin Apr 12, 2019
7ff5208
DQN + HER
araffin Apr 15, 2019
3e67330
Add support for SAC and DDPG
araffin Apr 16, 2019
dab5647
Add tests for SAC and DDPG + HER
araffin Apr 20, 2019
9e42f1e
Bug fix + add comments
araffin Apr 20, 2019
63ffc83
Add action noise for SAC
araffin Apr 20, 2019
2e79261
Add note about pop-art normalization
araffin Apr 21, 2019
12ab42e
Merge branch 'master' into HER-2
araffin Apr 21, 2019
eb0da05
Merge branch 'master' into HER-2
araffin Apr 22, 2019
a9f43af
Add saving/loading
araffin Apr 22, 2019
ca32a5f
Add success rate
araffin Apr 22, 2019
8023bbc
Fix HER learning method
araffin Apr 23, 2019
abe17f3
Merge branch 'master' into HER-2
araffin Apr 23, 2019
09e514d
Add support for VecEnv
araffin Apr 27, 2019
c6479e4
Update documentation
araffin Apr 27, 2019
c72e760
Add HER example
araffin Apr 28, 2019
fc3d592
Merge branch 'master' into HER-2
araffin Apr 28, 2019
20fda69
Merge branch 'master' into HER-2
araffin Apr 28, 2019
36fd201
Merge branch 'master' into HER-2
araffin Apr 30, 2019
5799fd9
Merge branch 'master' into HER-2
araffin May 4, 2019
88cb4e5
Removed unused dependencies (tqdm, dill, progressbar2, seaborn, glob2…
araffin May 4, 2019
6c7f5bb
Remove note on the replay buffer
araffin May 4, 2019
65d21e2
Update doc + add a check for VecEnvWrapper with HER
araffin May 5, 2019
8723869
Update examples + add notebook for HER
araffin May 5, 2019
ea1238b
Merge branch 'master' into HER-2
araffin May 9, 2019
0a3b789
Merge branch 'master' into HER-2
araffin May 11, 2019
6ef753d
Merge branch 'master' into HER-2
araffin May 15, 2019
157b005
Merge branch 'master' into HER-2
araffin May 18, 2019
0be6f84
Add random exploration for SAC and DDPG
araffin May 19, 2019
b208889
Typo in docstring
araffin May 19, 2019
27699bf
Doc update: add fix for DDPG saved models
araffin May 19, 2019
3dfe6b1
Merge branch 'master' into HER-2
araffin May 21, 2019
87db166
Test with reward offset
araffin May 22, 2019
1a7e090
Add GoalEnvNormalize draft
araffin May 22, 2019
7592bbd
Remove GoalEnvNormalize
araffin May 23, 2019
aebdfe9
Merge branch 'master' into HER-2
araffin May 23, 2019
edfe3c3
Merge branch 'master' into HER-2
araffin May 30, 2019
730b171
Fix typo
araffin May 31, 2019
635c7d0
Bug fix for HER + VecEnv
araffin Jun 1, 2019
bf363ad
Fix HER test env
araffin Jun 1, 2019
ccbc5c7
Fixed key order
araffin Jun 1, 2019
e1e344b
Add support for discrete obs space
araffin Jun 2, 2019
096f045
Update doc about reproducing experiments
araffin Jun 2, 2019
7688838
Update doc: DDPG supports multiprocessing with MPI
araffin Jun 2, 2019
5c24590
Merge branch 'master' into HER-2
araffin Jun 2, 2019
cd18225
Fix for new abstract method
araffin Jun 2, 2019
65ef631
Update changelog
araffin Jun 2, 2019
84af166
Fix custom policy example
araffin Jun 4, 2019
e2408eb
Add replay_wrapper to base OffPolicy class
araffin Jun 4, 2019
6ed497d
Fix reimport
araffin Jun 4, 2019
2 changes: 2 additions & 0 deletions docs/misc/changelog.rst
@@ -10,6 +10,8 @@ Pre-Release 2.5.1a0 (WIP)

- doc update (fix example of result plotter + improve doc)
- fixed logger issues when stdout lacks ``read`` function
- **deprecated** ``memory_limit`` and ``memory_policy`` in DDPG, please use ``buffer_size`` instead (they will be removed in v3.x.x)
- removed ``stable_baselines.ddpg.memory`` in favor of ``stable_baselines.deepq.replay_buffer``


Release 2.5.0 (2019-03-28)
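
For reference, a minimal migration sketch for the deprecation noted in the changelog above (a hypothetical script; 'Pendulum-v0' is only an example environment):

from stable_baselines import DDPG

# Old style, now deprecated: emits a DeprecationWarning and is mapped to buffer_size
# model = DDPG('MlpPolicy', 'Pendulum-v0', memory_limit=50000)

# New style
model = DDPG('MlpPolicy', 'Pendulum-v0', buffer_size=50000)
model.learn(total_timesteps=10000)
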
1 change: 1 addition & 0 deletions stable_baselines/__init__.py
@@ -3,6 +3,7 @@
from stable_baselines.acktr import ACKTR
from stable_baselines.ddpg import DDPG
from stable_baselines.deepq import DQN
from stable_baselines.her import HER
from stable_baselines.gail import GAIL
from stable_baselines.ppo1 import PPO1
from stable_baselines.ppo2 import PPO2
74 changes: 74 additions & 0 deletions stable_baselines/common/bit_flipping_env.py
@@ -0,0 +1,74 @@
import numpy as np
from gym import GoalEnv, spaces


class BitFlippingEnv(GoalEnv):
"""
Simple bit flipping env, useful to test HER.
The goal is to flip all the bits to get a vector of ones.
In the continuous variant, if the ith action component has a value > 0,
then the ith bit will be flipped.

:param n_bits: (int) Number of bits to flip
:param continuous: (bool) Whether to use the continuous version or not
:param max_steps: (int) Max number of steps, by default equal to n_bits
"""
def __init__(self, n_bits=10, continuous=False, max_steps=None):
super(BitFlippingEnv, self).__init__()
# The achieved goal is determined by the current state
# here, it is a special case where they are equal
self.observation_space = spaces.Dict({
'observation': spaces.MultiBinary(n_bits),
'achieved_goal': spaces.MultiBinary(n_bits),
'desired_goal': spaces.MultiBinary(n_bits)
})
if continuous:
self.action_space = spaces.Box(-1, 1, shape=(n_bits,), dtype=np.float32)
else:
self.action_space = spaces.Discrete(n_bits)
self.continuous = continuous
self.state = None
self.desired_goal = np.ones((n_bits,))
if max_steps is None:
max_steps = n_bits
self.max_steps = max_steps
self.current_step = 0
self.reset()

def reset(self):
self.current_step = 0
self.state = self.observation_space.spaces['observation'].sample()
return {
'observation': self.state.copy(),
'achieved_goal': self.state.copy(),
'desired_goal': self.desired_goal.copy()
}

def step(self, action):
if self.continuous:
self.state[action > 0] = 1 - self.state[action > 0]
else:
self.state[action] = 1 - self.state[action]
obs = {
'observation': self.state.copy(),
'achieved_goal': self.state.copy(),
'desired_goal': self.desired_goal.copy()
}
reward = self.compute_reward(obs['achieved_goal'], obs['desired_goal'], None)
done = (obs['achieved_goal'] == obs['desired_goal']).all()
self.current_step += 1
# Episode terminates when we reach the goal or the max number of steps
done = done or self.current_step >= self.max_steps
return obs, reward, done, {}

def compute_reward(self, achieved_goal, desired_goal, _info):
# Deceptive reward: it is positive only when the goal is achieved
return 0 if (achieved_goal == desired_goal).all() else -1

def render(self, mode='human'):
if mode == 'rgb_array':
return self.state.copy()
print(self.state)

def close(self):
pass
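
As a usage sketch, the env above can be paired with the HER wrapper added in this PR. The wrapper signature and the keyword names (n_sampled_goal, goal_selection_strategy) follow the later commits of this PR and should be treated as illustrative:

from stable_baselines import DQN, HER
from stable_baselines.common.bit_flipping_env import BitFlippingEnv

# Discrete variant: each action flips one bit; reward is -1 until all bits match the goal
env = BitFlippingEnv(n_bits=10, continuous=False)

# HER relabels transitions from failed episodes with the goals actually achieved,
# which makes the sparse -1/0 reward informative
model = HER('MlpPolicy', env, DQN, n_sampled_goal=4,
            goal_selection_strategy='future', verbose=1)
model.learn(total_timesteps=100000)

obs = env.reset()
action, _ = model.predict(obs)
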
111 changes: 71 additions & 40 deletions stable_baselines/ddpg/ddpg.py
@@ -18,7 +18,7 @@
from stable_baselines.ddpg.policies import DDPGPolicy
from stable_baselines.common.mpi_running_mean_std import RunningMeanStd
from stable_baselines.a2c.utils import find_trainable_variables, total_episode_reward_logger
from stable_baselines.ddpg.memory import Memory
from stable_baselines.deepq.replay_buffer import ReplayBuffer


def normalize(tensor, stats):
@@ -114,7 +114,7 @@ def get_perturbed_actor_updates(actor, perturbed_actor, param_noise_stddev, verb

assert len(tf_util.get_globals_vars(actor)) == len(tf_util.get_globals_vars(perturbed_actor))
assert len([var for var in tf_util.get_trainable_vars(actor) if 'LayerNorm' not in var.name]) == \
len([var for var in tf_util.get_trainable_vars(perturbed_actor) if 'LayerNorm' not in var.name])
len([var for var in tf_util.get_trainable_vars(perturbed_actor) if 'LayerNorm' not in var.name])

updates = []
for var, perturbed_var in zip(tf_util.get_globals_vars(actor), tf_util.get_globals_vars(perturbed_actor)):
@@ -140,7 +140,10 @@ class DDPG(OffPolicyRLModel):
:param policy: (DDPGPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, LnMlpPolicy, ...)
:param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
:param gamma: (float) the discount factor
:param memory_policy: (Memory) the replay buffer (if None, default to baselines.ddpg.memory.Memory)
:param memory_policy: (ReplayBuffer) the replay buffer
(if None, default to baselines.deepq.replay_buffer.ReplayBuffer)
.. deprecated:: 2.6.0
This parameter will be removed in a future version
:param eval_env: (Gym Environment) the evaluation environment (can be None)
:param nb_train_steps: (int) the number of training steps
:param nb_rollout_steps: (int) the number of rollout steps
@@ -163,7 +166,10 @@ class DDPG(OffPolicyRLModel):
:param reward_scale: (float) the value the reward should be scaled by
:param render: (bool) enable rendering of the environment
:param render_eval: (bool) enable rendering of the evaluation environment
:param memory_limit: (int) the max number of transitions to store
:param memory_limit: (int) the max number of transitions to store, size of the replay buffer
.. deprecated:: 2.6.0
Use `buffer_size` instead.
:param buffer_size: (int) the max number of transitions to store, size of the replay buffer
:param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
:param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
@@ -177,17 +183,28 @@ def __init__(self, policy, env, gamma=0.99, memory_policy=None, eval_env=None, n
normalize_observations=False, tau=0.001, batch_size=128, param_noise_adaption_interval=50,
normalize_returns=False, enable_popart=False, observation_range=(-5., 5.), critic_l2_reg=0.,
return_range=(-np.inf, np.inf), actor_lr=1e-4, critic_lr=1e-3, clip_norm=None, reward_scale=1.,
render=False, render_eval=False, memory_limit=50000, verbose=0, tensorboard_log=None,
_init_setup_model=True, policy_kwargs=None, full_tensorboard_log=False):
render=False, render_eval=False, memory_limit=None, buffer_size=50000,
verbose=0, tensorboard_log=None, _init_setup_model=True, policy_kwargs=None,
full_tensorboard_log=False):

# TODO: replay_buffer refactoring
super(DDPG, self).__init__(policy=policy, env=env, replay_buffer=None, verbose=verbose, policy_base=DDPGPolicy,
super(DDPG, self).__init__(policy=policy, env=env, replay_buffer=None,
verbose=verbose, policy_base=DDPGPolicy,
requires_vec_env=False, policy_kwargs=policy_kwargs)

# Parameters.
self.gamma = gamma
self.tau = tau
self.memory_policy = memory_policy or Memory

# TODO: remove this param in v3.x.x
if memory_policy is not None:
warnings.warn("memory_policy will be removed in a future version (v3.x.x), "
"it is now ignored and replaced with ReplayBuffer", DeprecationWarning)

if memory_limit is not None:
warnings.warn("memory_limit will be removed in a future version (v3.x.x), "
"use buffer_size instead", DeprecationWarning)
buffer_size = memory_limit

self.normalize_observations = normalize_observations
self.normalize_returns = normalize_returns
self.action_noise = action_noise
@@ -209,13 +226,14 @@ def __init__(self, policy, env, gamma=0.99, memory_policy=None, eval_env=None, n
self.nb_train_steps = nb_train_steps
self.nb_rollout_steps = nb_rollout_steps
self.memory_limit = memory_limit
self.buffer_size = buffer_size
self.tensorboard_log = tensorboard_log
self.full_tensorboard_log = full_tensorboard_log

# init
self.graph = None
self.stats_sample = None
self.memory = None
self.replay_buffer = None
self.policy_tf = None
self.target_init_updates = None
self.target_soft_updates = None
@@ -265,6 +283,8 @@ def __init__(self, policy, env, gamma=0.99, memory_policy=None, eval_env=None, n
self.tb_seen_steps = None

self.target_params = None
self.obs_rms_params = None
self.ret_rms_params = None

if _init_setup_model:
self.setup_model()
@@ -287,8 +307,7 @@ def setup_model(self):
with self.graph.as_default():
self.sess = tf_util.single_threaded_session(graph=self.graph)

self.memory = self.memory_policy(limit=self.memory_limit, action_shape=self.action_space.shape,
observation_shape=self.observation_space.shape)
self.replay_buffer = ReplayBuffer(self.buffer_size)

with tf.variable_scope("input", reuse=False):
# Observation normalization.
@@ -401,9 +420,9 @@ def setup_model(self):
self.params = find_trainable_variables("model")
self.target_params = find_trainable_variables("target")
self.obs_rms_params = [var for var in tf.global_variables()
if "obs_rms" in var.name]
if "obs_rms" in var.name]
self.ret_rms_params = [var for var in tf.global_variables()
if "ret_rms" in var.name]
if "ret_rms" in var.name]

with self.sess.as_default():
self._initialize(self.sess)
@@ -596,10 +615,10 @@ def _store_transition(self, obs0, action, reward, obs1, terminal1):
:param action: ([float]) the action
:param reward: (float) the reward
:param obs1: ([float] or [int]) the current observation
:param terminal1: (bool) is the episode done
:param terminal1: (bool) Whether the episode is over
"""
reward *= self.reward_scale
self.memory.append(obs0, action, reward, obs1, terminal1)
self.replay_buffer.add(obs0, action, reward, obs1, float(terminal1))
if self.normalize_observations:
self.obs_rms.update(np.array([obs0]))

@@ -613,14 +632,17 @@ def _train_step(self, step, writer, log=False):
:return: (float, float) critic loss, actor loss
"""
# Get a batch
batch = self.memory.sample(batch_size=self.batch_size)
obs0, actions, rewards, obs1, terminals1 = self.replay_buffer.sample(batch_size=self.batch_size)
# Reshape to match previous behavior and placeholder shape
rewards = rewards.reshape(-1, 1)
terminals1 = terminals1.reshape(-1, 1)

if self.normalize_returns and self.enable_popart:
old_mean, old_std, target_q = self.sess.run([self.ret_rms.mean, self.ret_rms.std, self.target_q],
feed_dict={
self.obs_target: batch['obs1'],
self.rewards: batch['rewards'],
self.terminals1: batch['terminals1'].astype('float32')
self.obs_target: obs1,
self.rewards: rewards,
self.terminals1: terminals1
})
self.ret_rms.update(target_q.flatten())
self.sess.run(self.renormalize_q_outputs_op, feed_dict={
Expand All @@ -630,18 +652,18 @@ def _train_step(self, step, writer, log=False):

else:
target_q = self.sess.run(self.target_q, feed_dict={
self.obs_target: batch['obs1'],
self.rewards: batch['rewards'],
self.terminals1: batch['terminals1'].astype('float32')
self.obs_target: obs1,
self.rewards: rewards,
self.terminals1: terminals1
})

# Get all gradients and perform a synced update.
ops = [self.actor_grads, self.actor_loss, self.critic_grads, self.critic_loss]
td_map = {
self.obs_train: batch['obs0'],
self.actions: batch['actions'],
self.action_train_ph: batch['actions'],
self.rewards: batch['rewards'],
self.obs_train: obs0,
self.actions: actions,
self.action_train_ph: actions,
self.rewards: rewards,
self.critic_target: target_q,
self.param_noise_stddev: 0 if self.param_noise is None else self.param_noise.current_stddev
}
@@ -695,7 +717,14 @@ def _get_stats(self):
if self.stats_sample is None:
# Get a sample and keep that fixed for all further computations.
# This allows us to estimate the change in value for the same set of inputs.
self.stats_sample = self.memory.sample(batch_size=self.batch_size)
obs0, actions, rewards, obs1, terminals1 = self.replay_buffer.sample(batch_size=self.batch_size)
self.stats_sample = {
'obs0': obs0,
'actions': actions,
'rewards': rewards,
'obs1': obs1,
'terminals1': terminals1
}

feed_dict = {
self.actions: self.stats_sample['actions']
@@ -730,12 +759,12 @@ def _adapt_param_noise(self):
return 0.

# Perturb a separate copy of the policy to adjust the scale for the next "real" perturbation.
batch = self.memory.sample(batch_size=self.batch_size)
obs0, *_ = self.replay_buffer.sample(batch_size=self.batch_size)
self.sess.run(self.perturb_adaptive_policy_ops, feed_dict={
self.param_noise_stddev: self.param_noise.current_stddev,
})
distance = self.sess.run(self.adaptive_policy_distance, feed_dict={
self.obs_adapt_noise: batch['obs0'], self.obs_train: batch['obs0'],
self.obs_adapt_noise: obs0, self.obs_train: obs0,
self.param_noise_stddev: self.param_noise.current_stddev,
})

@@ -755,10 +784,13 @@ def _reset(self):
})

def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="DDPG",
reset_num_timesteps=True):
reset_num_timesteps=True, replay_wrapper=None):

new_tb_log = self._init_num_timesteps(reset_num_timesteps)

if replay_wrapper is not None:
self.replay_buffer = replay_wrapper(self.replay_buffer)

with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \
as writer:
self._setup_learn(seed)
@@ -862,7 +894,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_
epoch_adaptive_distances = []
for t_train in range(self.nb_train_steps):
# Adapt param noise, if necessary.
if self.memory.nb_entries >= self.batch_size and \
if len(self.replay_buffer) >= self.batch_size and \
t_train % self.param_noise_adaption_interval == 0:
distance = self._adapt_param_noise()
epoch_adaptive_distances.append(distance)
@@ -1013,8 +1045,8 @@ def save(self, save_path):
"clip_norm": self.clip_norm,
"reward_scale": self.reward_scale,
"memory_limit": self.memory_limit,
"buffer_size": self.buffer_size,
"policy": self.policy,
"memory_policy": self.memory_policy,
"n_envs": self.n_envs,
"_vectorize_action": self._vectorize_action,
"policy_kwargs": self.policy_kwargs
@@ -1026,14 +1058,13 @@ def save(self, save_path):
norm_ret_params = self.sess.run(self.ret_rms_params)

params_to_save = params \
+ target_params \
+ norm_obs_params \
+ norm_ret_params
+ target_params \
+ norm_obs_params \
+ norm_ret_params
self._save_to_file(save_path,
data=data,
params=params_to_save)


@classmethod
def load(cls, load_path, env=None, **kwargs):
data, params = cls._load_from_file(load_path)
@@ -1051,9 +1082,9 @@ def load(cls, load_path, env=None, **kwargs):

restores = []
params_to_load = model.params \
+ model.target_params \
+ model.obs_rms_params \
+ model.ret_rms_params
+ model.target_params \
+ model.obs_rms_params \
+ model.ret_rms_params
for param, loaded_p in zip(params_to_load, params):
restores.append(param.assign(loaded_p))
model.sess.run(restores)
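
The switch from Memory to ReplayBuffer above changes sampling from a dict of named arrays to a plain tuple, which is what the new unpacking and reshape in _train_step reflect. A standalone sketch of the buffer interface used here (dummy data, outside of DDPG):

import numpy as np
from stable_baselines.deepq.replay_buffer import ReplayBuffer

buffer = ReplayBuffer(50000)  # only takes a maximum size, no observation/action shapes

obs0, obs1 = np.zeros(3), np.ones(3)
# done is stored as a float, matching _store_transition
buffer.add(obs0, np.array([0.1]), -1.0, obs1, float(False))

if len(buffer) >= 1:  # len() replaces memory.nb_entries
    obs0_b, actions_b, rewards_b, obs1_b, dones_b = buffer.sample(batch_size=1)
    # rewards and dones come back as 1-D arrays, hence the reshape(-1, 1) in _train_step
    rewards_b = rewards_b.reshape(-1, 1)
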
1 change: 0 additions & 1 deletion stable_baselines/ddpg/main.py
@@ -11,7 +11,6 @@
from stable_baselines.common.misc_util import set_global_seeds, boolean_flag
from stable_baselines.ddpg.policies import MlpPolicy, LnMlpPolicy
from stable_baselines.ddpg import DDPG
from stable_baselines.ddpg.memory import Memory
from stable_baselines.ddpg.noise import AdaptiveParamNoiseSpec, OrnsteinUhlenbeckActionNoise, NormalActionNoise

