From a615b2aa9cfcbaef91ad28848323b024907045fa Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 11 Apr 2019 11:11:28 +0200 Subject: [PATCH 01/36] Add bit flipping env --- stable_baselines/common/bit_flipping_env.py | 75 +++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 stable_baselines/common/bit_flipping_env.py diff --git a/stable_baselines/common/bit_flipping_env.py b/stable_baselines/common/bit_flipping_env.py new file mode 100644 index 0000000000..33ec9b947d --- /dev/null +++ b/stable_baselines/common/bit_flipping_env.py @@ -0,0 +1,75 @@ +import numpy as np +from gym import GoalEnv, spaces + + +class BitFlippingEnv(GoalEnv): + """ + Simple bit flipping env, useful to test HER. + The goal is to flip all the bits to get a vector of ones. + In the continuous variant, if the ith action component has a value > 0, + then the ith bit will be flipped. + + :param n_bits: (int) Number of bits to flip + :param continuous: (bool) Wether to use the continuous version or not + :param max_steps: (int) Max number of steps, by defaults, equal to n_bits + """ + def __init__(self, n_bits=10, continuous=False, max_steps=None): + super(BitFlippingEnv, self).__init__() + # The achieved goal is determined by the current state + # here, it is a special where they are equal + self.observation_space = spaces.Dict({ + 'observation': spaces.MultiBinary(n_bits), + 'achieved_goal': spaces.MultiBinary(n_bits), + 'desired_goal': spaces.MultiBinary(n_bits) + }) + if continuous: + self.action_space = spaces.Box(-1, 1, shape=(n_bits,), dtype=np.float32) + else: + self.action_space = spaces.Discrete(n_bits) + self.continuous = continuous + self.state = None + self.desired_goal = np.ones((n_bits,)) + if max_steps is None: + max_steps = n_bits + self.max_steps = max_steps + self.current_step = 0 + self.reset() + + def reset(self): + self.current_step = 0 + self.state = self.observation_space.spaces['observation'].sample() + return { + 'observation': self.state.copy(), + 'achieved_goal': self.state.copy(), + 'desired_goal': self.desired_goal.copy() + } + + def step(self, action): + if self.continuous: + self.state[action > 0] = 1 - self.state[action > 0] + else: + self.state[action] = 1 - self.state[action] + obs = { + 'observation': self.state.copy(), + 'achieved_goal': self.state.copy(), + 'desired_goal': self.desired_goal.copy() + } + reward = self.compute_reward(obs['achieved_goal'], obs['desired_goal'], None) + done = (obs['achieved_goal'] == obs['desired_goal']).all() + self.current_step += 1 + # Episode terminate when we reached the goal or the max number of steps + done = done or self.current_step >= self.max_steps + return obs, reward, done, {} + + @staticmethod + def compute_reward(achieved_goal, desired_goal, _info): + # Deceptive reward: it is positive only when the goal is achieved + return 0 if (achieved_goal == desired_goal).all() else -1 + + def render(mode='human'): + if mode == 'rgb_array': + return self.state.copy() + print(self.state) + + def close(): + pass From 5bfa61c3767d8584c21dfc2a9077ec0a48b29150 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 12 Apr 2019 13:38:10 +0200 Subject: [PATCH 02/36] HER reloaded (WIP) --- stable_baselines/her/actor_critic.py | 52 --- stable_baselines/her/ddpg.py | 417 -------------------- stable_baselines/her/experiment/__init__.py | 0 stable_baselines/her/experiment/config.py | 215 ---------- stable_baselines/her/experiment/play.py | 69 ---- stable_baselines/her/experiment/plot.py | 141 ------- stable_baselines/her/experiment/train.py | 236 ----------- stable_baselines/her/her.py | 92 +---- stable_baselines/her/normalizer.py | 199 ---------- stable_baselines/her/replay_buffer.py | 231 +++++------ stable_baselines/her/rollout.py | 228 ----------- stable_baselines/her/util.py | 150 ------- stable_baselines/her/utils.py | 54 +++ 13 files changed, 187 insertions(+), 1897 deletions(-) delete mode 100644 stable_baselines/her/actor_critic.py delete mode 100644 stable_baselines/her/ddpg.py delete mode 100644 stable_baselines/her/experiment/__init__.py delete mode 100644 stable_baselines/her/experiment/config.py delete mode 100644 stable_baselines/her/experiment/play.py delete mode 100644 stable_baselines/her/experiment/plot.py delete mode 100644 stable_baselines/her/experiment/train.py delete mode 100644 stable_baselines/her/normalizer.py delete mode 100644 stable_baselines/her/rollout.py delete mode 100644 stable_baselines/her/util.py create mode 100644 stable_baselines/her/utils.py diff --git a/stable_baselines/her/actor_critic.py b/stable_baselines/her/actor_critic.py deleted file mode 100644 index e108b69215..0000000000 --- a/stable_baselines/her/actor_critic.py +++ /dev/null @@ -1,52 +0,0 @@ -import tensorflow as tf - -from stable_baselines.her.util import mlp - - -class ActorCritic: - def __init__(self, inputs_tf, dim_obs, dim_goal, dim_action, - max_u, o_stats, g_stats, hidden, layers, **kwargs): - """The actor-critic network and related training code. - - :param inputs_tf: ({str: TensorFlow Tensor}) all necessary inputs for the network: the - observation (o), the goal (g), and the action (u) - :param dim_obs: (int) the dimension of the observations - :param dim_goal: (int) the dimension of the goals - :param dim_action: (int) the dimension of the actions - :param max_u: (float) the maximum magnitude of actions; action outputs will be scaled accordingly - :param o_stats (stable_baselines.her.Normalizer): normalizer for observations - :param g_stats (stable_baselines.her.Normalizer): normalizer for goals - :param hidden (int): number of hidden units that should be used in hidden layers - :param layers (int): number of hidden layers - """ - self.inputs_tf = inputs_tf - self.dim_obs = dim_obs - self.dim_goal = dim_goal - self.dim_action = dim_action - self.max_u = max_u - self.o_stats = o_stats - self.g_stats = g_stats - self.hidden = hidden - self.layers = layers - - self.o_tf = inputs_tf['o'] - self.g_tf = inputs_tf['g'] - self.u_tf = inputs_tf['u'] - - # Prepare inputs for actor and critic. - obs = self.o_stats.normalize(self.o_tf) - goals = self.g_stats.normalize(self.g_tf) - input_pi = tf.concat(axis=1, values=[obs, goals]) # for actor - - # Networks. - with tf.variable_scope('pi'): - self.pi_tf = self.max_u * tf.tanh(mlp( - input_pi, [self.hidden] * self.layers + [self.dimu])) - with tf.variable_scope('Q'): - # for policy training - input_q = tf.concat(axis=1, values=[obs, goals, self.pi_tf / self.max_u]) - self.q_pi_tf = mlp(input_q, [self.hidden] * self.layers + [1]) - # for critic training - input_q = tf.concat(axis=1, values=[obs, goals, self.u_tf / self.max_u]) - self._input_q = input_q # exposed for tests - self.q_tf = mlp(input_q, [self.hidden] * self.layers + [1], reuse=True) diff --git a/stable_baselines/her/ddpg.py b/stable_baselines/her/ddpg.py deleted file mode 100644 index f023234232..0000000000 --- a/stable_baselines/her/ddpg.py +++ /dev/null @@ -1,417 +0,0 @@ -from collections import OrderedDict - -import numpy as np -import tensorflow as tf -from tensorflow.contrib.staging import StagingArea - -from stable_baselines import logger -from stable_baselines.her.util import import_function, flatten_grads, transitions_in_episode_batch -from stable_baselines.her.normalizer import Normalizer -from stable_baselines.her.replay_buffer import ReplayBuffer -from stable_baselines.common.mpi_adam import MpiAdam - - -def dims_to_shapes(input_dims): - return {key: tuple([val]) if val > 0 else tuple() for key, val in input_dims.items()} - - -class DDPG(object): - def __init__(self, input_dims, buffer_size, hidden, layers, network_class, polyak, batch_size, - q_lr, pi_lr, norm_eps, norm_clip, max_u, action_l2, clip_obs, scope, time_horizon, - rollout_batch_size, subtract_goals, relative_goals, clip_pos_returns, clip_return, - sample_transitions, gamma, reuse=False): - """ - Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER). - - :param input_dims: ({str: int}) dimensions for the observation (o), the goal (g), and the actions (u) - :param buffer_size: (int) number of transitions that are stored in the replay buffer - :param hidden: (int) number of units in the hidden layers - :param layers: (int) number of hidden layers - :param network_class: (str) the network class that should be used (e.g. 'stable_baselines.her.ActorCritic') - :param polyak: (float) coefficient for Polyak-averaging of the target network - :param batch_size: (int) batch size for training - :param q_lr: (float) learning rate for the Q (critic) network - :param pi_lr: (float) learning rate for the pi (actor) network - :param norm_eps: (float) a small value used in the normalizer to avoid numerical instabilities - :param norm_clip: (float) normalized inputs are clipped to be in [-norm_clip, norm_clip] - :param max_u: (float) maximum action magnitude, i.e. actions are in [-max_u, max_u] - :param action_l2: (float) coefficient for L2 penalty on the actions - :param clip_obs: (float) clip observations before normalization to be in [-clip_obs, clip_obs] - :param scope: (str) the scope used for the TensorFlow graph - :param time_horizon: (int) the time horizon for rollouts - :param rollout_batch_size: (int) number of parallel rollouts per DDPG agent - :param subtract_goals: (function (np.ndarray, np.ndarray): np.ndarray) function that subtracts goals - from each other - :param relative_goals: (boolean) whether or not relative goals should be fed into the network - :param clip_pos_returns: (boolean) whether or not positive returns should be clipped - :param clip_return: (float) clip returns to be in [-clip_return, clip_return] - :param sample_transitions: (function (dict, int): dict) function that samples from the replay buffer - :param gamma: (float) gamma used for Q learning updates - :param reuse: (boolean) whether or not the networks should be reused - """ - # Updated in experiments/config.py - self.input_dims = input_dims - self.buffer_size = buffer_size - self.hidden = hidden - self.layers = layers - self.network_class = network_class - self.polyak = polyak - self.batch_size = batch_size - self.q_lr = q_lr - self.pi_lr = pi_lr - self.norm_eps = norm_eps - self.norm_clip = norm_clip - self.max_u = max_u - self.action_l2 = action_l2 - self.clip_obs = clip_obs - self.scope = scope - self.time_horizon = time_horizon - self.rollout_batch_size = rollout_batch_size - self.subtract_goals = subtract_goals - self.relative_goals = relative_goals - self.clip_pos_returns = clip_pos_returns - self.clip_return = clip_return - self.sample_transitions = sample_transitions - self.gamma = gamma - self.reuse = reuse - - if self.clip_return is None: - self.clip_return = np.inf - - self.create_actor_critic = import_function(self.network_class) - - input_shapes = dims_to_shapes(self.input_dims) - self.dim_obs = self.input_dims['o'] - self.dim_goal = self.input_dims['g'] - self.dim_action = self.input_dims['u'] - - # Prepare staging area for feeding data to the model. - stage_shapes = OrderedDict() - for key in sorted(self.input_dims.keys()): - if key.startswith('info_'): - continue - stage_shapes[key] = (None, *input_shapes[key]) - for key in ['o', 'g']: - stage_shapes[key + '_2'] = stage_shapes[key] - stage_shapes['r'] = (None,) - self.stage_shapes = stage_shapes - - # Create network. - with tf.variable_scope(self.scope): - self.staging_tf = StagingArea( - dtypes=[tf.float32 for _ in self.stage_shapes.keys()], - shapes=list(self.stage_shapes.values())) - self.buffer_ph_tf = [ - tf.placeholder(tf.float32, shape=shape) for shape in self.stage_shapes.values()] - self.stage_op = self.staging_tf.put(self.buffer_ph_tf) - - self._create_network(reuse=reuse) - - # Configure the replay buffer. - buffer_shapes = {key: (self.time_horizon if key != 'o' else self.time_horizon + 1, *input_shapes[key]) - for key, val in input_shapes.items()} - buffer_shapes['g'] = (buffer_shapes['g'][0], self.dim_goal) - buffer_shapes['ag'] = (self.time_horizon + 1, self.dim_goal) - - buffer_size = (self.buffer_size // self.rollout_batch_size) * self.rollout_batch_size - self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.time_horizon, self.sample_transitions) - - def _random_action(self, num): - return np.random.uniform(low=-self.max_u, high=self.max_u, size=(num, self.dim_action)) - - def _preprocess_obs_goal(self, obs, achieved_goal, goal): - if self.relative_goals: - g_shape = goal.shape - goal = goal.reshape(-1, self.dim_goal) - achieved_goal = achieved_goal.reshape(-1, self.dim_goal) - goal = self.subtract_goals(goal, achieved_goal) - goal = goal.reshape(*g_shape) - obs = np.clip(obs, -self.clip_obs, self.clip_obs) - goal = np.clip(goal, -self.clip_obs, self.clip_obs) - return obs, goal - - def get_actions(self, obs, achieved_goal, goal, noise_eps=0., random_eps=0., use_target_net=False, compute_q=False): - """ - return the action from an observation and goal - - :param obs: (np.ndarray) the observation - :param achieved_goal: (np.ndarray) the achieved goal - :param goal: (np.ndarray) the goal - :param noise_eps: (float) the noise epsilon - :param random_eps: (float) the random epsilon - :param use_target_net: (bool) whether or not to use the target network - :param compute_q: (bool) whether or not to compute Q value - :return: (numpy float or float) the actions - """ - obs, goal = self._preprocess_obs_goal(obs, achieved_goal, goal) - policy = self.target if use_target_net else self.main - # values to compute - vals = [policy.pi_tf] - if compute_q: - vals += [policy.q_pi_tf] - # feed - feed = { - policy.o_tf: obs.reshape(-1, self.dim_obs), - policy.g_tf: goal.reshape(-1, self.dim_goal), - policy.u_tf: np.zeros((obs.size // self.dim_obs, self.dim_action), dtype=np.float32) - } - - ret = self.sess.run(vals, feed_dict=feed) - # action postprocessing - action = ret[0] - noise = noise_eps * self.max_u * np.random.randn(*action.shape) # gaussian noise - action += noise - action = np.clip(action, -self.max_u, self.max_u) - # eps-greedy - n_ac = action.shape[0] - action += np.random.binomial(1, random_eps, n_ac).reshape(-1, 1) * (self._random_action(n_ac) - action) - if action.shape[0] == 1: - action = action[0] - action = action.copy() - ret[0] = action - - if len(ret) == 1: - return ret[0] - else: - return ret - - def store_episode(self, episode_batch, update_stats=True): - """ - Story the episode transitions - - :param episode_batch: (np.ndarray) array of batch_size x (T or T+1) x dim_key 'o' is of size T+1, - others are of size T - :param update_stats: (bool) whether to update stats or not - """ - - self.buffer.store_episode(episode_batch) - - if update_stats: - # add transitions to normalizer - episode_batch['o_2'] = episode_batch['o'][:, 1:, :] - episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :] - num_normalizing_transitions = transitions_in_episode_batch(episode_batch) - transitions = self.sample_transitions(episode_batch, num_normalizing_transitions) - - obs, _, goal, achieved_goal = transitions['o'], transitions['o_2'], transitions['g'], transitions['ag'] - transitions['o'], transitions['g'] = self._preprocess_obs_goal(obs, achieved_goal, goal) - # No need to preprocess the o_2 and g_2 since this is only used for stats - - self.o_stats.update(transitions['o']) - self.g_stats.update(transitions['g']) - - self.o_stats.recompute_stats() - self.g_stats.recompute_stats() - - def get_current_buffer_size(self): - """ - returns the current buffer size - - :return: (int) buffer size - """ - return self.buffer.get_current_size() - - def _sync_optimizers(self): - self.q_adam.sync() - self.pi_adam.sync() - - def _grads(self): - # Avoid feed_dict here for performance! - critic_loss, actor_loss, q_grad, pi_grad = self.sess.run([ - self.q_loss_tf, - self.main.q_pi_tf, - self.q_grad_tf, - self.pi_grad_tf - ]) - return critic_loss, actor_loss, q_grad, pi_grad - - def _update(self, q_grad, pi_grad): - self.q_adam.update(q_grad, self.q_lr) - self.pi_adam.update(pi_grad, self.pi_lr) - - def sample_batch(self): - """ - sample a batch - - :return: (dict) the batch - """ - transitions = self.buffer.sample(self.batch_size) - obs, obs_2, goal = transitions['o'], transitions['o_2'], transitions['g'] - achieved_goal, achieved_goal_2 = transitions['ag'], transitions['ag_2'] - transitions['o'], transitions['g'] = self._preprocess_obs_goal(obs, achieved_goal, goal) - transitions['o_2'], transitions['g_2'] = self._preprocess_obs_goal(obs_2, achieved_goal_2, goal) - - transitions_batch = [transitions[key] for key in self.stage_shapes.keys()] - return transitions_batch - - def stage_batch(self, batch=None): - """ - apply a batch to staging - - :param batch: (dict) the batch to add to staging, if None: self.sample_batch() - """ - if batch is None: - batch = self.sample_batch() - assert len(self.buffer_ph_tf) == len(batch) - self.sess.run(self.stage_op, feed_dict=dict(zip(self.buffer_ph_tf, batch))) - - def train(self, stage=True): - """ - train DDPG - - :param stage: (bool) enable staging - :return: (float, float) critic loss, actor loss - """ - if stage: - self.stage_batch() - critic_loss, actor_loss, q_grad, pi_grad = self._grads() - self._update(q_grad, pi_grad) - return critic_loss, actor_loss - - def _init_target_net(self): - self.sess.run(self.init_target_net_op) - - def update_target_net(self): - """ - update the target network - """ - self.sess.run(self.update_target_net_op) - - def clear_buffer(self): - """ - clears the replay buffer - """ - self.buffer.clear_buffer() - - def _vars(self, scope): - res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope + '/' + scope) - assert len(res) > 0 - return res - - def _global_vars(self, scope): - res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.scope + '/' + scope) - return res - - def _create_network(self, reuse=False): - logger.info("Creating a DDPG agent with action space %d x %s..." % (self.dim_action, self.max_u)) - - self.sess = tf.get_default_session() - if self.sess is None: - self.sess = tf.InteractiveSession() - - # running averages - with tf.variable_scope('o_stats') as scope: - if reuse: - scope.reuse_variables() - self.o_stats = Normalizer(self.dim_obs, self.norm_eps, self.norm_clip, sess=self.sess) - with tf.variable_scope('g_stats') as scope: - if reuse: - scope.reuse_variables() - self.g_stats = Normalizer(self.dim_goal, self.norm_eps, self.norm_clip, sess=self.sess) - - # mini-batch sampling. - batch = self.staging_tf.get() - batch_tf = OrderedDict([(key, batch[i]) - for i, key in enumerate(self.stage_shapes.keys())]) - batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1]) - - # networks - with tf.variable_scope('main') as scope: - if reuse: - scope.reuse_variables() - self.main = self.create_actor_critic(batch_tf, net_type='main', **self.__dict__) - scope.reuse_variables() - with tf.variable_scope('target') as scope: - if reuse: - scope.reuse_variables() - target_batch_tf = batch_tf.copy() - target_batch_tf['o'] = batch_tf['o_2'] - target_batch_tf['g'] = batch_tf['g_2'] - self.target = self.create_actor_critic( - target_batch_tf, net_type='target', **self.__dict__) - scope.reuse_variables() - assert len(self._vars("main")) == len(self._vars("target")) - - # loss functions - target_q_pi_tf = self.target.q_pi_tf - clip_range = (-self.clip_return, 0. if self.clip_pos_returns else np.inf) - target_tf = tf.clip_by_value(batch_tf['r'] + self.gamma * target_q_pi_tf, *clip_range) - - self.q_loss_tf = tf.reduce_mean(tf.square(tf.stop_gradient(target_tf) - self.main.q_tf)) - self.pi_loss_tf = -tf.reduce_mean(self.main.q_pi_tf) - self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u)) - - q_grads_tf = tf.gradients(self.q_loss_tf, self._vars('main/Q')) - pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi')) - - assert len(self._vars('main/Q')) == len(q_grads_tf) - assert len(self._vars('main/pi')) == len(pi_grads_tf) - - self.q_grads_vars_tf = zip(q_grads_tf, self._vars('main/Q')) - self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi')) - self.q_grad_tf = flatten_grads(grads=q_grads_tf, var_list=self._vars('main/Q')) - self.pi_grad_tf = flatten_grads(grads=pi_grads_tf, var_list=self._vars('main/pi')) - - # optimizers - self.q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False) - self.pi_adam = MpiAdam(self._vars('main/pi'), scale_grad_by_procs=False) - - # polyak averaging - self.main_vars = self._vars('main/Q') + self._vars('main/pi') - self.target_vars = self._vars('target/Q') + self._vars('target/pi') - self.stats_vars = self._global_vars('o_stats') + self._global_vars('g_stats') - self.init_target_net_op = list( - map(lambda v: v[0].assign(v[1]), zip(self.target_vars, self.main_vars))) - self.update_target_net_op = list( - map(lambda v: v[0].assign(self.polyak * v[0] + (1. - self.polyak) * v[1]), - zip(self.target_vars, self.main_vars))) - - # initialize all variables - tf.variables_initializer(self._global_vars('')).run() - self._sync_optimizers() - self._init_target_net() - - def logs(self, prefix=''): - """ - create a log dictionary - :param prefix: (str) the prefix for evey index - :return: ({str: Any}) the log - """ - logs = [] - logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))] - logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))] - logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))] - logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))] - - if prefix is not '' and not prefix.endswith('/'): - return [(prefix + '/' + key, val) for key, val in logs] - else: - return logs - - def __getstate__(self): - """Our policies can be loaded from pkl, but after unpickling you cannot continue training. - """ - excluded_subnames = ['_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats', - 'main', 'target', 'lock', 'env', 'sample_transitions', - 'stage_shapes', 'create_actor_critic'] - - state = {k: v for k, v in self.__dict__.items() if all([subname not in k for subname in excluded_subnames])} - state['buffer_size'] = self.buffer_size - state['tf'] = self.sess.run([x for x in self._global_vars('') if 'buffer' not in x.name]) - return state - - def __setstate__(self, state): - if 'sample_transitions' not in state: - # We don't need this for playing the policy. - state['sample_transitions'] = None - - self.__init__(**state) - # set up stats (they are overwritten in __init__) - for key, value in state.items(): - if key[-6:] == '_stats': - self.__dict__[key] = value - # load TF variables - _vars = [x for x in self._global_vars('') if 'buffer' not in x.name] - assert len(_vars) == len(state["tf"]) - node = [tf.assign(var, val) for var, val in zip(_vars, state["tf"])] - self.sess.run(node) diff --git a/stable_baselines/her/experiment/__init__.py b/stable_baselines/her/experiment/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/stable_baselines/her/experiment/config.py b/stable_baselines/her/experiment/config.py deleted file mode 100644 index 19a597fe09..0000000000 --- a/stable_baselines/her/experiment/config.py +++ /dev/null @@ -1,215 +0,0 @@ -import numpy as np -import gym - -from stable_baselines import logger -from stable_baselines.her.ddpg import DDPG -from stable_baselines.her.her import make_sample_her_transitions - - -DEFAULT_ENV_PARAMS = { - 'FetchReach-v1': { - 'n_cycles': 10, - }, -} - - -DEFAULT_PARAMS = { - # env - 'max_u': 1., # max absolute value of actions on different coordinates - # ddpg - 'layers': 3, # number of layers in the critic/actor networks - 'hidden': 256, # number of neurons in each hidden layers - 'network_class': 'stable_baselines.her.actor_critic:ActorCritic', - 'q_lr': 0.001, # critic learning rate - 'pi_lr': 0.001, # actor learning rate - 'buffer_size': int(1E6), # for experience replay - 'polyak': 0.95, # polyak averaging coefficient - 'action_l2': 1.0, # quadratic penalty on actions (before rescaling by max_u) - 'clip_obs': 200., - 'scope': 'ddpg', # can be tweaked for testing - 'relative_goals': False, - # training - 'n_cycles': 50, # per epoch - 'rollout_batch_size': 2, # per mpi thread - 'n_batches': 40, # training batches per cycle - 'batch_size': 256, # per mpi thread, measured in transitions and reduced to even multiple of chunk_length. - 'n_test_rollouts': 10, # number of test rollouts per epoch, each consists of rollout_batch_size rollouts - 'test_with_polyak': False, # run test episodes with the target network - # exploration - 'random_eps': 0.3, # percentage of time a random action is taken - 'noise_eps': 0.2, # std of gaussian noise added to not-completely-random actions as a percentage of max_u - # HER - 'replay_strategy': 'future', # supported modes: future, none - 'replay_k': 4, # number of additional goals used for replay, only used if off_policy_data=future - # normalization - 'norm_eps': 0.01, # epsilon used for observation normalization - 'norm_clip': 5, # normalized observations are cropped to this values -} - - -CACHED_ENVS = {} - - -def cached_make_env(make_env): - """ - Only creates a new environment from the provided function if one has not yet already been - created. This is useful here because we need to infer certain properties of the env, e.g. - its observation and action spaces, without any intend of actually using it. - - :param make_env: (function (): Gym Environment) creates the environment - :return: (Gym Environment) the created environment - """ - if make_env not in CACHED_ENVS: - env = make_env() - CACHED_ENVS[make_env] = env - return CACHED_ENVS[make_env] - - -def prepare_params(kwargs): - """ - prepares DDPG params from kwargs - - :param kwargs: (dict) the input kwargs - :return: (dict) DDPG parameters - """ - # DDPG params - ddpg_params = dict() - - env_name = kwargs['env_name'] - - def make_env(): - return gym.make(env_name) - kwargs['make_env'] = make_env - tmp_env = cached_make_env(kwargs['make_env']) - assert hasattr(tmp_env, '_max_episode_steps') - kwargs['time_horizon'] = tmp_env.spec.max_episode_steps # wrapped envs preserve their spec - tmp_env.reset() - kwargs['max_u'] = np.array(kwargs['max_u']) if isinstance(kwargs['max_u'], list) else kwargs['max_u'] - kwargs['gamma'] = 1. - 1. / kwargs['time_horizon'] - if 'lr' in kwargs: - kwargs['pi_lr'] = kwargs['lr'] - kwargs['q_lr'] = kwargs['lr'] - del kwargs['lr'] - for name in ['buffer_size', 'hidden', 'layers', - 'network_class', - 'polyak', - 'batch_size', 'q_lr', 'pi_lr', - 'norm_eps', 'norm_clip', 'max_u', - 'action_l2', 'clip_obs', 'scope', 'relative_goals']: - ddpg_params[name] = kwargs[name] - kwargs['_' + name] = kwargs[name] - del kwargs[name] - kwargs['ddpg_params'] = ddpg_params - - return kwargs - - -def log_params(params, logger_input=logger): - """ - log the parameters - - :param params: (dict) parameters to log - :param logger_input: (logger) the logger - """ - for key in sorted(params.keys()): - logger_input.info('{}: {}'.format(key, params[key])) - - -def configure_her(params): - """ - configure hindsight experience replay - - :param params: (dict) input parameters - :return: (function (dict, int): dict) returns a HER update function for replay buffer batch - """ - env = cached_make_env(params['make_env']) - env.reset() - - def reward_fun(achieved_goal, goal, info): # vectorized - return env.compute_reward(achieved_goal=achieved_goal, desired_goal=goal, info=info) - - # Prepare configuration for HER. - her_params = { - 'reward_fun': reward_fun, - } - for name in ['replay_strategy', 'replay_k']: - her_params[name] = params[name] - params['_' + name] = her_params[name] - del params[name] - sample_her_transitions = make_sample_her_transitions(**her_params) - - return sample_her_transitions - - -def simple_goal_subtract(vec_a, vec_b): - """ - checks if a and b have the same shape, and does a - b - - :param vec_a: (np.ndarray) - :param vec_b: (np.ndarray) - :return: (np.ndarray) a - b - """ - assert vec_a.shape == vec_b.shape - return vec_a - vec_b - - -def configure_ddpg(dims, params, reuse=False, use_mpi=True, clip_return=True): - """ - configure a DDPG model from parameters - - :param dims: ({str: int}) the dimensions - :param params: (dict) the DDPG parameters - :param reuse: (bool) whether or not the networks should be reused - :param use_mpi: (bool) whether or not to use MPI - :param clip_return: (float) clip returns to be in [-clip_return, clip_return] - :return: (her.DDPG) the ddpg model - """ - sample_her_transitions = configure_her(params) - # Extract relevant parameters. - gamma = params['gamma'] - rollout_batch_size = params['rollout_batch_size'] - ddpg_params = params['ddpg_params'] - - input_dims = dims.copy() - - # DDPG agent - env = cached_make_env(params['make_env']) - env.reset() - ddpg_params.update({'input_dims': input_dims, # agent takes an input observations - 'time_horizon': params['time_horizon'], - 'clip_pos_returns': True, # clip positive returns - 'clip_return': (1. / (1. - gamma)) if clip_return else np.inf, # max abs of return - 'rollout_batch_size': rollout_batch_size, - 'subtract_goals': simple_goal_subtract, - 'sample_transitions': sample_her_transitions, - 'gamma': gamma, - }) - ddpg_params['info'] = { - 'env_name': params['env_name'], - } - policy = DDPG(reuse=reuse, **ddpg_params, use_mpi=use_mpi) - return policy - - -def configure_dims(params): - """ - configure input and output dimensions - - :param params: (dict) the parameters - :return: ({str: int}) the dimensions - """ - env = cached_make_env(params['make_env']) - env.reset() - obs, _, _, info = env.step(env.action_space.sample()) - - dims = { - 'o': obs['observation'].shape[0], - 'u': env.action_space.shape[0], - 'g': obs['desired_goal'].shape[0], - } - for key, value in info.items(): - value = np.array(value) - if value.ndim == 0: - value = value.reshape(1) - dims['info_{}'.format(key)] = value.shape[0] - return dims diff --git a/stable_baselines/her/experiment/play.py b/stable_baselines/her/experiment/play.py deleted file mode 100644 index 6d01e03ea1..0000000000 --- a/stable_baselines/her/experiment/play.py +++ /dev/null @@ -1,69 +0,0 @@ -import click -import pickle - -import numpy as np - -from stable_baselines import logger -from stable_baselines.common import set_global_seeds -import stable_baselines.her.experiment.config as config -from stable_baselines.her.rollout import RolloutWorker - - -@click.command() -@click.argument('policy_file', type=str) -@click.option('--seed', type=int, default=0) -@click.option('--n_test_rollouts', type=int, default=10) -@click.option('--render', type=int, default=1) -def main(policy_file, seed, n_test_rollouts, render): - """ - run HER from a saved policy - - :param policy_file: (str) pickle path to a saved policy - :param seed: (int) initial seed - :param n_test_rollouts: (int) the number of test rollouts - :param render: (bool) if rendering should be done - """ - set_global_seeds(seed) - - # Load policy. - with open(policy_file, 'rb') as file_handler: - policy = pickle.load(file_handler) - env_name = policy.info['env_name'] - - # Prepare params. - params = config.DEFAULT_PARAMS - if env_name in config.DEFAULT_ENV_PARAMS: - params.update(config.DEFAULT_ENV_PARAMS[env_name]) # merge env-specific parameters in - params['env_name'] = env_name - params = config.prepare_params(params) - config.log_params(params, logger_input=logger) - - dims = config.configure_dims(params) - - eval_params = { - 'exploit': True, - 'use_target_net': params['test_with_polyak'], - 'compute_q': True, - 'rollout_batch_size': 1, - 'render': bool(render), - } - - for name in ['time_horizon', 'gamma', 'noise_eps', 'random_eps']: - eval_params[name] = params[name] - - evaluator = RolloutWorker(params['make_env'], policy, dims, logger, **eval_params) - evaluator.seed(seed) - - # Run evaluation. - evaluator.clear_history() - for _ in range(n_test_rollouts): - evaluator.generate_rollouts() - - # record logs - for key, val in evaluator.logs('test'): - logger.record_tabular(key, np.mean(val)) - logger.dump_tabular() - - -if __name__ == '__main__': - main() diff --git a/stable_baselines/her/experiment/plot.py b/stable_baselines/her/experiment/plot.py deleted file mode 100644 index e9ee808a2e..0000000000 --- a/stable_baselines/her/experiment/plot.py +++ /dev/null @@ -1,141 +0,0 @@ -import os -import json -import argparse - -import matplotlib.pyplot as plt -import numpy as np -import seaborn as sns -import glob2 - -# Initialize seaborn -sns.set() - -def smooth_reward_curve(x, y): - """ - smooth the reward curve - - :param x: (numpy float) the x coord of the reward - :param y: (numpy float) the y coord of the reward - :return: (numpy float, numpy float) smoothed x, smoothed y - """ - halfwidth = int(np.ceil(len(x) / 60)) # Halfwidth of our smoothing convolution - k = halfwidth - xsmoo = x - ysmoo = np.convolve(y, np.ones(2 * k + 1), mode='same') / np.convolve(np.ones_like(y), np.ones(2 * k + 1), - mode='same') - return xsmoo, ysmoo - - -def load_results(file): - """ - load the results from a file - - :param file: (str) the saved results - :return: (dict) the result - """ - if not os.path.exists(file): - return None - with open(file, 'r') as file_handler: - lines = [line for line in file_handler] - if len(lines) < 2: - return None - keys = [name.strip() for name in lines[0].split(',')] - data = np.genfromtxt(file, delimiter=',', skip_header=1, filling_values=0.) - if data.ndim == 1: - data = data.reshape(1, -1) - assert data.ndim == 2 - assert data.shape[-1] == len(keys) - result = {} - for idx, key in enumerate(keys): - result[key] = data[:, idx] - return result - - -def pad(xs, value=np.nan): - """ - - - :param xs: - :param value: - :return: - """ - maxlen = np.max([len(x) for x in xs]) - - padded_xs = [] - for x in xs: - if x.shape[0] >= maxlen: - padded_xs.append(x) - - padding = np.ones((maxlen - x.shape[0],) + x.shape[1:]) * value - x_padded = np.concatenate([x, padding], axis=0) - assert x_padded.shape[1:] == x.shape[1:] - assert x_padded.shape[0] == maxlen - padded_xs.append(x_padded) - return np.array(padded_xs) - - -parser = argparse.ArgumentParser() -parser.add_argument('dir', type=str) -parser.add_argument('--smooth', type=int, default=1) -args = parser.parse_args() - -# Load all data. -data = {} -paths = [os.path.abspath(os.path.join(path, '..')) for path in glob2.glob(os.path.join(args.dir, '**', 'progress.csv'))] -for curr_path in paths: - if not os.path.isdir(curr_path): - continue - results = load_results(os.path.join(curr_path, 'progress.csv')) - if not results: - print('skipping {}'.format(curr_path)) - continue - print('loading {} ({})'.format(curr_path, len(results['epoch']))) - with open(os.path.join(curr_path, 'params.json'), 'r') as f: - params = json.load(f) - - success_rate = np.array(results['test/success_rate']) - epoch = np.array(results['epoch']) + 1 - env_id = params['env_name'] - replay_strategy = params['replay_strategy'] - - if replay_strategy == 'future': - config = 'her' - else: - config = 'ddpg' - if 'Dense' in env_id: - config += '-dense' - else: - config += '-sparse' - env_id = env_id.replace('Dense', '') - - # Process and smooth data. - assert success_rate.shape == epoch.shape - x = epoch - y = success_rate - if args.smooth: - x, y = smooth_reward_curve(epoch, success_rate) - assert x.shape == y.shape - - if env_id not in data: - data[env_id] = {} - if config not in data[env_id]: - data[env_id][config] = [] - data[env_id][config].append((x, y)) - -# Plot data. -for env_id in sorted(data.keys()): - print('exporting {}'.format(env_id)) - plt.clf() - - for config in sorted(data[env_id].keys()): - xs, ys = zip(*data[env_id][config]) - xs, ys = pad(xs), pad(ys) - assert xs.shape == ys.shape - - plt.plot(xs[0], np.nanmedian(ys, axis=0), label=config) - plt.fill_between(xs[0], np.nanpercentile(ys, 25, axis=0), np.nanpercentile(ys, 75, axis=0), alpha=0.25) - plt.title(env_id) - plt.xlabel('Epoch') - plt.ylabel('Median Success Rate') - plt.legend() - plt.savefig(os.path.join(args.dir, 'fig_{}.png'.format(env_id))) diff --git a/stable_baselines/her/experiment/train.py b/stable_baselines/her/experiment/train.py deleted file mode 100644 index 027d1dec1a..0000000000 --- a/stable_baselines/her/experiment/train.py +++ /dev/null @@ -1,236 +0,0 @@ -import os -import sys -from subprocess import CalledProcessError - -import click -import numpy as np -import json -from mpi4py import MPI - -from stable_baselines import logger -from stable_baselines.common import set_global_seeds, tf_util -from stable_baselines.common.mpi_moments import mpi_moments -import stable_baselines.her.experiment.config as config -from stable_baselines.her.rollout import RolloutWorker -from stable_baselines.her.util import mpi_fork - - -def mpi_average(value): - """ - calculate the average from the array, using MPI - - :param value: (np.ndarray) the array - :return: (float) the average - """ - if len(value) == 0: - value = [0.] - if not isinstance(value, list): - value = [value] - return mpi_moments(np.array(value))[0] - - -def train(policy, rollout_worker, evaluator, n_epochs, n_test_rollouts, n_cycles, n_batches, policy_save_interval, - save_policies): - """ - train the given policy - - :param policy: (her.DDPG) the policy to train - :param rollout_worker: (RolloutWorker) Rollout worker generates experience for training. - :param evaluator: (RolloutWorker) Rollout worker for evalutation - :param n_epochs: (int) the number of epochs - :param n_test_rollouts: (int) the number of for the evalutation RolloutWorker - :param n_cycles: (int) the number of cycles for training per epoch - :param n_batches: (int) the batch size - :param policy_save_interval: (int) the interval with which policy pickles are saved. - If set to 0, only the best and latest policy will be pickled. - :param save_policies: (bool) whether or not to save the policies - """ - rank = MPI.COMM_WORLD.Get_rank() - - latest_policy_path = os.path.join(logger.get_dir(), 'policy_latest.pkl') - best_policy_path = os.path.join(logger.get_dir(), 'policy_best.pkl') - periodic_policy_path = os.path.join(logger.get_dir(), 'policy_{}.pkl') - - logger.info("Training...") - best_success_rate = -1 - for epoch in range(n_epochs): - # train - rollout_worker.clear_history() - for _ in range(n_cycles): - episode = rollout_worker.generate_rollouts() - policy.store_episode(episode) - for _ in range(n_batches): - policy.train_step() - policy.update_target_net() - - # test - evaluator.clear_history() - for _ in range(n_test_rollouts): - evaluator.generate_rollouts() - - # record logs - logger.record_tabular('epoch', epoch) - for key, val in evaluator.logs('test'): - logger.record_tabular(key, mpi_average(val)) - for key, val in rollout_worker.logs('train'): - logger.record_tabular(key, mpi_average(val)) - for key, val in policy.logs(): - logger.record_tabular(key, mpi_average(val)) - - if rank == 0: - logger.dump_tabular() - - # save the policy if it's better than the previous ones - success_rate = mpi_average(evaluator.current_success_rate()) - if rank == 0 and success_rate >= best_success_rate and save_policies: - best_success_rate = success_rate - logger.info('New best success rate: {}. Saving policy to {} ...' - .format(best_success_rate, best_policy_path)) - evaluator.save_policy(best_policy_path) - evaluator.save_policy(latest_policy_path) - if rank == 0 and policy_save_interval > 0 and epoch % policy_save_interval == 0 and save_policies: - policy_path = periodic_policy_path.format(epoch) - logger.info('Saving periodic policy to {} ...'.format(policy_path)) - evaluator.save_policy(policy_path) - - # make sure that different threads have different seeds - local_uniform = np.random.uniform(size=(1,)) - root_uniform = local_uniform.copy() - MPI.COMM_WORLD.Bcast(root_uniform, root=0) - if rank != 0: - assert local_uniform[0] != root_uniform[0] - - -def launch(env, logdir, n_epochs, num_cpu, seed, replay_strategy, policy_save_interval, clip_return, - override_params=None, save_policies=True): - """ - launch training with mpi - - :param env: (str) environment ID - :param logdir: (str) the log directory - :param n_epochs: (int) the number of training epochs - :param num_cpu: (int) the number of CPUs to run on - :param seed: (int) the initial random seed - :param replay_strategy: (str) the type of replay strategy ('future' or 'none') - :param policy_save_interval: (int) the interval with which policy pickles are saved. - If set to 0, only the best and latest policy will be pickled. - :param clip_return: (float): clip returns to be in [-clip_return, clip_return] - :param override_params: (dict) override any parameter for training - :param save_policies: (bool) whether or not to save the policies - """ - - if override_params is None: - override_params = {} - # Fork for multi-CPU MPI implementation. - if num_cpu > 1: - try: - whoami = mpi_fork(num_cpu, ['--bind-to', 'core']) - except CalledProcessError: - # fancy version of mpi call failed, try simple version - whoami = mpi_fork(num_cpu) - - if whoami == 'parent': - sys.exit(0) - tf_util.single_threaded_session().__enter__() - rank = MPI.COMM_WORLD.Get_rank() - - # Configure logging - if rank == 0: - if logdir or logger.get_dir() is None: - logger.configure(folder=logdir) - else: - logger.configure() - logdir = logger.get_dir() - assert logdir is not None - os.makedirs(logdir, exist_ok=True) - - # Seed everything. - rank_seed = seed + 1000000 * rank - set_global_seeds(rank_seed) - - # Prepare params. - params = config.DEFAULT_PARAMS - params['env_name'] = env - params['replay_strategy'] = replay_strategy - if env in config.DEFAULT_ENV_PARAMS: - params.update(config.DEFAULT_ENV_PARAMS[env]) # merge env-specific parameters in - params.update(**override_params) # makes it possible to override any parameter - with open(os.path.join(logger.get_dir(), 'params.json'), 'w') as file_handler: - json.dump(params, file_handler) - params = config.prepare_params(params) - config.log_params(params, logger_input=logger) - - if num_cpu == 1: - logger.warn() - logger.warn('*** Warning ***') - logger.warn( - 'You are running HER with just a single MPI worker. This will work, but the ' + - 'experiments that we report in Plappert et al. (2018, https://arxiv.org/abs/1802.09464) ' + - 'were obtained with --num_cpu 19. This makes a significant difference and if you ' + - 'are looking to reproduce those results, be aware of this. Please also refer to ' + - 'https://github.com/openai/stable_baselines/issues/314 for further details.') - logger.warn('****************') - logger.warn() - - dims = config.configure_dims(params) - policy = config.configure_ddpg(dims=dims, params=params, clip_return=clip_return) - - rollout_params = { - 'exploit': False, - 'use_target_net': False, - # 'use_demo_states': True, - 'compute_q': False, - 'time_horizon': params['time_horizon'], - } - - eval_params = { - 'exploit': True, - 'use_target_net': params['test_with_polyak'], - # 'use_demo_states': False, - 'compute_q': True, - 'time_horizon': params['time_horizon'], - } - - for name in ['time_horizon', 'rollout_batch_size', 'gamma', 'noise_eps', 'random_eps']: - rollout_params[name] = params[name] - eval_params[name] = params[name] - - rollout_worker = RolloutWorker(params['make_env'], policy, dims, logger, **rollout_params) - rollout_worker.seed(rank_seed) - - evaluator = RolloutWorker(params['make_env'], policy, dims, logger, **eval_params) - evaluator.seed(rank_seed) - - train( - policy=policy, rollout_worker=rollout_worker, - evaluator=evaluator, n_epochs=n_epochs, n_test_rollouts=params['n_test_rollouts'], - n_cycles=params['n_cycles'], n_batches=params['n_batches'], - policy_save_interval=policy_save_interval, save_policies=save_policies) - - -@click.command() -@click.option('--env', type=str, default='FetchReach-v1', - help='the name of the OpenAI Gym environment that you want to train on') -@click.option('--logdir', type=str, default=None, - help='the path to where logs and policy pickles should go. If not specified, creates a folder in /tmp/') -@click.option('--n_epochs', type=int, default=50, help='the number of training epochs to run') -@click.option('--num_cpu', type=int, default=1, help='the number of CPU cores to use (using MPI)') -@click.option('--seed', type=int, default=0, - help='the random seed used to seed both the environment and the training code') -@click.option('--policy_save_interval', type=int, default=5, - help='the interval with which policy pickles are saved. ' - 'If set to 0, only the best and latest policy will be pickled.') -@click.option('--replay_strategy', type=click.Choice(['future', 'none']), default='future', - help='the HER replay strategy to be used. "future" uses HER, "none" disables HER.') -@click.option('--clip_return', type=int, default=1, help='whether or not returns should be clipped') -def main(**kwargs): - """ - run launch for MPI HER DDPG training - - :param kwargs: (dict) the launch kwargs - """ - launch(**kwargs) - - -if __name__ == '__main__': - main() diff --git a/stable_baselines/her/her.py b/stable_baselines/her/her.py index 7bbac13d9e..2a0f063191 100644 --- a/stable_baselines/her/her.py +++ b/stable_baselines/her/her.py @@ -3,98 +3,32 @@ import gym from stable_baselines.common import BaseRLModel, SetVerbosity -from stable_baselines.common.policies import LstmPolicy, ActorCriticPolicy +from .utils import HERGoalEnvWrapper -def make_sample_her_transitions(replay_strategy, replay_k, reward_fun): +class HER(BaseRLModel): """ - Creates a sample function that can be used for HER experience replay. - - :param replay_strategy: (str) the HER replay strategy; if set to 'none', regular DDPG experience replay is used - (can be 'future' or 'none'). - :param replay_k: (int) the ratio between HER replays and regular replays (e.g. k = 4 -> 4 times - as many HER replays as regular replays are used) - :param reward_fun: (function (dict, dict): float) function to re-compute the reward with substituted goals + Hindsight Experience replay. """ - if replay_strategy == 'future': - future_p = 1 - (1. / (1 + replay_k)) - else: # 'replay_strategy' == 'none' - future_p = 0 - - def _sample_her_transitions(episode_batch, batch_size_in_transitions): - """episode_batch is {key: array(buffer_size x T x dim_key)} - """ - time_horizon = episode_batch['u'].shape[1] - rollout_batch_size = episode_batch['u'].shape[0] - batch_size = batch_size_in_transitions - - # Select which episodes and time steps to use. - episode_idxs = np.random.randint(0, rollout_batch_size, batch_size) - t_samples = np.random.randint(time_horizon, size=batch_size) - transitions = {key: episode_batch[key][episode_idxs, t_samples].copy() - for key in episode_batch.keys()} - - # Select future time indexes proportional with probability future_p. These - # will be used for HER replay by substituting in future goals. - her_indexes = np.where(np.random.uniform(size=batch_size) < future_p) - future_offset = np.random.uniform(size=batch_size) * (time_horizon - t_samples) - future_offset = future_offset.astype(int) - future_t = (t_samples + 1 + future_offset)[her_indexes] - - # Replace goal with achieved goal but only for the previously-selected - # HER transitions (as defined by her_indexes). For the other transitions, - # keep the original goal. - future_ag = episode_batch['ag'][episode_idxs[her_indexes], future_t] - transitions['g'][her_indexes] = future_ag - - # Reconstruct info dictionary for reward computation. - info = {} - for key, value in transitions.items(): - if key.startswith('info_'): - info[key.replace('info_', '')] = value - - # Re-compute reward since we may have substituted the goal. - reward_params = {k: transitions[k] for k in ['ag_2', 'g']} - reward_params['info'] = info - transitions['r'] = reward_fun(**reward_params) - - transitions = {k: transitions[k].reshape(batch_size, *transitions[k].shape[1:]) - for k in transitions.keys()} - - assert transitions['u'].shape[0] == batch_size_in_transitions - - return transitions - - return _sample_her_transitions + def __init__(self, policy, env, model_class, sampling_strategy, get_achieved_goal, + verbose=0, _init_setup_model=True): + # super().__init__(policy=policy, env=env, verbose=verbose, policy_base=None, requires_vec_env=False) + self.model_class = model_class + self.env = env + assert isinstance(self.env, gym.GoalEnv), "HER only supports gym.GoalEnv" + self.wrapped_env = HERGoalEnvWrapper(env) -class HER(BaseRLModel): - def __init__(self, policy, env, verbose=0, _init_setup_model=True): - super().__init__(policy=policy, env=env, verbose=verbose, policy_base=ActorCriticPolicy, requires_vec_env=False) + self.model = self.model_class(policy, self.wrapped_env) - self.policy = policy - self.sess = None - self.graph = None - - if _init_setup_model: - self.setup_model() def _get_pretrain_placeholders(self): raise NotImplementedError() - def setup_model(self): - with SetVerbosity(self.verbose): - - assert isinstance(self.action_space, gym.spaces.Box), \ - "Error: HER cannot output a {} action space, only spaces.Box is supported.".format(self.action_space) - assert not issubclass(self.policy, LstmPolicy), "Error: cannot use a recurrent policy for the HER model." - assert issubclass(self.policy, ActorCriticPolicy), "Error: the input policy for the HER model must be an " \ - "instance of common.policies.ActorCriticPolicy." - self.graph = tf.Graph() - with self.graph.as_default(): - pass + def setup_model(self): + pass def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="HER", reset_num_timesteps=True): diff --git a/stable_baselines/her/normalizer.py b/stable_baselines/her/normalizer.py deleted file mode 100644 index 4427507d15..0000000000 --- a/stable_baselines/her/normalizer.py +++ /dev/null @@ -1,199 +0,0 @@ -import threading - -import numpy as np -from mpi4py import MPI -import tensorflow as tf - -from stable_baselines.her.util import reshape_for_broadcasting - - -class Normalizer: - def __init__(self, size, eps=1e-2, default_clip_range=np.inf, sess=None): - """ - A normalizer that ensures that observations are approximately distributed according to - a standard Normal distribution (i.e. have mean zero and variance one). - - :param size: (int) the size of the observation to be normalized - :param eps: (float) a small constant that avoids underflows - :param default_clip_range: (float) normalized observations are clipped to be in - [-default_clip_range, default_clip_range] - :param sess: (TensorFlow Session) the TensorFlow session to be used - """ - self.size = size - self.eps = eps - self.default_clip_range = default_clip_range - self.sess = sess if sess is not None else tf.get_default_session() - - self.local_sum = np.zeros(self.size, np.float32) - self.local_sumsq = np.zeros(self.size, np.float32) - self.local_count = np.zeros(1, np.float32) - - self.sum_tf = tf.get_variable( - initializer=tf.zeros_initializer(), shape=self.local_sum.shape, name='sum', - trainable=False, dtype=tf.float32) - self.sumsq_tf = tf.get_variable( - initializer=tf.zeros_initializer(), shape=self.local_sumsq.shape, name='sumsq', - trainable=False, dtype=tf.float32) - self.count_tf = tf.get_variable( - initializer=tf.ones_initializer(), shape=self.local_count.shape, name='count', - trainable=False, dtype=tf.float32) - self.mean = tf.get_variable( - initializer=tf.zeros_initializer(), shape=(self.size,), name='mean', - trainable=False, dtype=tf.float32) - self.std = tf.get_variable( - initializer=tf.ones_initializer(), shape=(self.size,), name='std', - trainable=False, dtype=tf.float32) - self.count_pl = tf.placeholder(name='count_pl', shape=(1,), dtype=tf.float32) - self.sum_pl = tf.placeholder(name='sum_pl', shape=(self.size,), dtype=tf.float32) - self.sumsq_pl = tf.placeholder(name='sumsq_pl', shape=(self.size,), dtype=tf.float32) - - self.update_op = tf.group( - self.count_tf.assign_add(self.count_pl), - self.sum_tf.assign_add(self.sum_pl), - self.sumsq_tf.assign_add(self.sumsq_pl) - ) - self.recompute_op = tf.group( - tf.assign(self.mean, self.sum_tf / self.count_tf), - tf.assign(self.std, tf.sqrt(tf.maximum( - tf.square(self.eps), - self.sumsq_tf / self.count_tf - tf.square(self.sum_tf / self.count_tf) - ))), - ) - self.lock = threading.Lock() - - def update(self, arr): - """ - update the parameters from the input - - :param arr: (np.ndarray) the input - """ - arr = arr.reshape(-1, self.size) - - with self.lock: - self.local_sum += arr.sum(axis=0) - self.local_sumsq += (np.square(arr)).sum(axis=0) - self.local_count[0] += arr.shape[0] - - def normalize(self, arr, clip_range=None): - """ - normalize the input - - :param arr: (np.ndarray) the input - :param clip_range: (float) the range to clip to [-clip_range, clip_range] - :return: (np.ndarray) normalized input - """ - if clip_range is None: - clip_range = self.default_clip_range - mean = reshape_for_broadcasting(self.mean, arr) - std = reshape_for_broadcasting(self.std, arr) - return tf.clip_by_value((arr - mean) / std, -clip_range, clip_range) - - def denormalize(self, arr): - """ - denormalize the input - - :param arr: (np.ndarray) the normalized input - :return: (np.ndarray) original input - """ - mean = reshape_for_broadcasting(self.mean, arr) - std = reshape_for_broadcasting(self.std, arr) - return mean + arr * std - - @classmethod - def _mpi_average(cls, arr): - buf = np.zeros_like(arr) - MPI.COMM_WORLD.Allreduce(arr, buf, op=MPI.SUM) - buf /= MPI.COMM_WORLD.Get_size() - return buf - - def synchronize(self, local_sum, local_sumsq, local_count): - """ - syncronize over mpi threads - - :param local_sum: (np.ndarray) the sum - :param local_sumsq: (np.ndarray) the square root sum - :param local_count: (np.ndarray) the number of values updated - :return: (np.ndarray, np.ndarray, np.ndarray) the updated local_sum, local_sumsq, and local_count - """ - local_sum[...] = self._mpi_average(local_sum) - local_sumsq[...] = self._mpi_average(local_sumsq) - local_count[...] = self._mpi_average(local_count) - return local_sum, local_sumsq, local_count - - def recompute_stats(self): - """ - recompute the stats - """ - with self.lock: - # Copy over results. - local_count = self.local_count.copy() - local_sum = self.local_sum.copy() - local_sumsq = self.local_sumsq.copy() - - # Reset. - self.local_count[...] = 0 - self.local_sum[...] = 0 - self.local_sumsq[...] = 0 - - # We perform the synchronization outside of the lock to keep the critical section as short - # as possible. - synced_sum, synced_sumsq, synced_count = self.synchronize( - local_sum=local_sum, local_sumsq=local_sumsq, local_count=local_count) - - self.sess.run(self.update_op, feed_dict={ - self.count_pl: synced_count, - self.sum_pl: synced_sum, - self.sumsq_pl: synced_sumsq, - }) - self.sess.run(self.recompute_op) - - -class IdentityNormalizer: - def __init__(self, size, std=1.): - """ - Normalizer that returns the input unchanged - - :param size: (int or [int]) the shape of the input to normalize - :param std: (float) the initial standard deviation or the normalization - """ - self.size = size - self.mean = tf.zeros(self.size, tf.float32) - self.std = std * tf.ones(self.size, tf.float32) - - def update(self, arr): - """ - update the parameters from the input - - :param arr: (np.ndarray) the input - """ - pass - - def normalize(self, arr, **_kwargs): - """ - normalize the input - - :param arr: (np.ndarray) the input - :return: (np.ndarray) normalized input - """ - return arr / self.std - - def denormalize(self, arr): - """ - denormalize the input - - :param arr: (np.ndarray) the normalized input - :return: (np.ndarray) original input - """ - return self.std * arr - - def synchronize(self): - """ - syncronize over mpi threads - """ - pass - - def recompute_stats(self): - """ - recompute the stats - """ - pass diff --git a/stable_baselines/her/replay_buffer.py b/stable_baselines/her/replay_buffer.py index 455dc04fff..c5d865523a 100644 --- a/stable_baselines/her/replay_buffer.py +++ b/stable_baselines/her/replay_buffer.py @@ -1,132 +1,141 @@ -import threading +from enum import Enum +import copy import numpy as np +from stable_baselines.deepq.replay_buffer import ReplayBuffer -class ReplayBuffer: - def __init__(self, buffer_shapes, size_in_transitions, time_horizon, sample_transitions): - """ - Creates a replay buffer. - - :param buffer_shapes: ({str: int}) the shape for all buffers that are used in the replay buffer - :param size_in_transitions: (int) the size of the buffer, measured in transitions - :param time_horizon: (int) the time horizon for episodes - :param sample_transitions: (function) a function that samples from the replay buffer - """ - self.buffer_shapes = buffer_shapes - self.size = size_in_transitions // time_horizon - self.time_horizon = time_horizon - self.sample_transitions = sample_transitions - - # self.buffers is {key: array(size_in_episodes x T or T+1 x dim_key)} - self.buffers = {key: np.empty([self.size, *shape]) - for key, shape in buffer_shapes.items()} - # memory management - self.current_size = 0 - self.n_transitions_stored = 0 +class GoalSelectionStrategy(Enum): + FUTURE = 0 + FINAL = 1 + EPISODE = 2 + RANDOM = 3 - self.lock = threading.Lock() - - @property - def full(self): - with self.lock: - return self.current_size == self.size - - def sample(self, batch_size): - """ - sample random transitions - :param batch_size: (int) How many transitions to sample. - :return: (dict) {key: array(batch_size x shapes[key])} +class HindsightExperienceReplayBuffer(ReplayBuffer): + def __init__(self, size, n_sampled_goal, goal_selection_strategy, env): """ - buffers = {} - with self.lock: - assert self.current_size > 0 - for key in self.buffers.keys(): - buffers[key] = self.buffers[key][:self.current_size] + Inspired by https://github.com/NervanaSystems/coach/. - buffers['o_2'] = buffers['o'][:, 1:, :] - buffers['ag_2'] = buffers['ag'][:, 1:, :] - - transitions = self.sample_transitions(buffers, batch_size) - - for key in (['r', 'o_2', 'ag_2'] + list(self.buffers.keys())): - assert key in transitions, "key %s missing from transitions" % key - - return transitions - - def store_episode(self, episode_batch): + :param size: (int) Max number of transitions to store in the buffer. When the buffer overflows the old + memories are dropped. + :param n_sampled_goal: The number of artificial transitions to generate for each actual transition + :param goal_selection_strategy: The method that will be used for generating the goals for the + hindsight transitions. Should be one of GoalSelectionStrategy + :param env: """ - Store an episode in the replay buffer - - :param episode_batch: (np.ndarray) batch_size x (T or T+1) x dim_key + super(HER, self).__init__(size) + self.n_sampled_goal = n_sampled_goal + self.goal_selection_strategy = goal_selection_strategy + self.env = env + self.current_episode = [] + self.get_achieved_goal = None + + def add(self, obs_t, action, reward, obs_tp1, done): """ - batch_sizes = [len(episode_batch[key]) for key in episode_batch.keys()] - assert np.all(np.array(batch_sizes) == batch_sizes[0]) - batch_size = batch_sizes[0] - - with self.lock: - idxs = self._get_storage_idx(batch_size) + add a new transition to the buffer - # load inputs into buffers - for key in self.buffers.keys(): - self.buffers[key][idxs] = episode_batch[key] - - self.n_transitions_stored += batch_size * self.time_horizon - - def get_current_episode_size(self): - """ - get current episode size - - :return: (int) the current size of the episode - """ - with self.lock: - return self.current_size - - def get_current_size(self): + :param obs_t: (Any) the last observation + :param action: ([float]) the action + :param reward: (float) the reward of the transition + :param obs_tp1: (Any) the current observation + :param done: (bool) is the episode done """ - get current size of the buffer - - :return: (int) the current size of the buffer + self.current_episode.append((obs_t, action, reward, obs_tp1, done)) + if done: + # Add transitions (and imagined ones) to buffer only when an episode is over + self._store_episode() + self.current_episode = [] + + # def _encode_sample(self, idxes): + # obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], [] + # for i in idxes: + # data = self._storage[i] + # obs_t, action, reward, obs_tp1, done = data + # obses_t.append(np.array(obs_t, copy=False)) + # actions.append(np.array(action, copy=False)) + # rewards.append(reward) + # obses_tp1.append(np.array(obs_tp1, copy=False)) + # dones.append(done) + # return np.array(obses_t), np.array(actions), np.array(rewards), np.array(obses_tp1), np.array(dones) + # + + def _sample_goal(self, episode_transitions, transition_idx): """ - with self.lock: - return self.current_size * self.time_horizon + Sample an achieved goal according to the sampling method. - def get_transitions_stored(self): + :param episode_transitions: a list of all the transitions in the current episode + :param transition_idx: the transition to start sampling from + :return: (np.ndarray) an achieved goal """ - get the number of stored transitions + if self.goal_selection_strategy == GoalSelectionStrategy.FUTURE: + # Sample a goal that was observed in the same episode after the current step + selected_obs = np.random.choice(episode_transitions[transition_idx + 1:]) + elif self.goal_selection_strategy == GoalSelectionStrategy.FINAL: + # The achieved goal at the end of the episode + selected_obs = episode_transitions[-1] + elif self.goal_selection_strategy == GoalSelectionStrategy.EPISODE: + # Random achieved goal during the episode + selected_obs = np.random.choice(episode_transitions) + elif self.goal_selection_strategy == GoalSelectionStrategy.RANDOM: + # Random achieved goal from the entire replay buffer + # selected_obs = np.random.choice(self._storage) + raise NotImplementedError() + else: + raise ValueError("Invalid goal selection strategy," + "please use one of {}".format(list(GoalSelectionStrategy))) + return self.get_achieved_goal(selected_obs) - :return: (int) the number of transitions stored + def _sample_goals(self, episode_transitions, transition_idx): """ - with self.lock: - return self.n_transitions_stored + Sample a batch of achieved goal according to the sampling strategy. - def clear_buffer(self): - """ - clear the buffer of all entries + :param episode_transitions: () a list of all the transitions in the current episode + :param transition_idx: the transition to start sampling from + :return: a goal corresponding to the sampled obs """ - with self.lock: - self.current_size = 0 - - def _get_storage_idx(self, inc=None): - inc = inc or 1 # size increment - assert inc <= self.size, "Batch committed to replay is too large!" - # go consecutively until you hit the end, and then go randomly. - if self.current_size + inc <= self.size: - idx = np.arange(self.current_size, self.current_size + inc) - elif self.current_size < self.size: - overflow = inc - (self.size - self.current_size) - idx_a = np.arange(self.current_size, self.size) - idx_b = np.random.randint(0, self.current_size, overflow) - idx = np.concatenate([idx_a, idx_b]) - else: - idx = np.random.randint(0, self.size, inc) - - # update replay size - self.current_size = min(self.size, self.current_size + inc) - - if inc == 1: - idx = idx[0] - return idx + return [ + self._sample_goal(episode_transitions, transition_idx) + for _ in range(self.n_sampled_goal) + ] + + def _store_episode(self): + last_episode_transitions = copy.deepcopy(self.current_episode) + + # for each transition in the last episode, create a set of hindsight transitions + for transition_idx, transition in enumerate(last_episode_transitions): + + obs_t, action, reward, obs_tp1, done = transition + # Add to the replay buffer + super().add(obs_t, action, reward, obs_tp1, done) + # We cannot sample a goal from the future in the last step of an episode + if (transition_idx == len(last_episode_transitions) - 1 and + self.goal_selection_strategy == GoalSelectionStrategy.FUTURE): + break + + # Sampled n goals per transition, where n is `n_sampled_goal` + # this is called k in the paper + sampled_goals = self._sample_goals(last_episode_transitions, transition_idx) + # For each sampled goals, store a new transition + for goal in sampled_goals: + obs, action, reward, next_obs, done = copy.deepcopy(transition) + + # Convert concatenated obs to dict + # so we can update the goals + obs_dict, next_obs_dict = map(self.env.convert_obs_to_dict, (obs, next_obs)) + + # update the desired goal in the transition + obs_dict['desired_goal'] = goal + next_obs_dict['desired_goal'] = goal + + # update the reward and terminal signal according to the goal + reward = self.env.compute_reward(goal, next_obs_dict['achieved_goal']) + # Can we ensure that done = reward == 0 + done = False + + obs, next_obs = map(self.env.convert_dict_to_obs, (obs_dict, next_obs_dict)) + + # Add to the replay buffer + super().add(obs, action, reward, next_obs, done) diff --git a/stable_baselines/her/rollout.py b/stable_baselines/her/rollout.py deleted file mode 100644 index aa85e719f2..0000000000 --- a/stable_baselines/her/rollout.py +++ /dev/null @@ -1,228 +0,0 @@ -from collections import deque -import pickle - -import numpy as np -from mujoco_py import MujocoException - -from stable_baselines.her.util import convert_episode_to_batch_major - - -class RolloutWorker: - def __init__(self, make_env, policy, dims, logger, time_horizon, rollout_batch_size=1, - exploit=False, use_target_net=False, compute_q=False, noise_eps=0, - random_eps=0, history_len=100, render=False): - """ - Rollout worker generates experience by interacting with one or many environments. - - :param make_env: (function (): Gym Environment) a factory function that creates a new instance of the - environment when called - :param policy: (Object) the policy that is used to act - :param dims: ({str: int}) the dimensions for observations (o), goals (g), and actions (u) - :param logger: (Object) the logger that is used by the rollout worker - :param rollout_batch_size: (int) the number of parallel rollouts that should be used - :param exploit: (bool) whether or not to exploit, i.e. to act optimally according to the current policy without - any exploration - :param use_target_net: (bool) whether or not to use the target net for rollouts - :param compute_q: (bool) whether or not to compute the Q values alongside the actions - :param noise_eps: (float) scale of the additive Gaussian noise - :param random_eps: (float) probability of selecting a completely random action - :param history_len: (int) length of history for statistics smoothing - :param render: (boolean) whether or not to render the rollouts - """ - self.make_env = make_env - self.policy = policy - self.dims = dims - self.logger = logger - self.time_horizon = time_horizon - self.rollout_batch_size = rollout_batch_size - self.exploit = exploit - self.use_target_net = use_target_net - self.compute_q = compute_q - self.noise_eps = noise_eps - self.random_eps = random_eps - self.history_len = history_len - self.render = render - - self.envs = [make_env() for _ in range(rollout_batch_size)] - assert self.time_horizon > 0 - - self.info_keys = [key.replace('info_', '') for key in dims.keys() if key.startswith('info_')] - - self.success_history = deque(maxlen=history_len) - self.q_history = deque(maxlen=history_len) - - self.n_episodes = 0 - self.goals = np.empty((self.rollout_batch_size, self.dims['g']), np.float32) # goals - self.initial_obs = np.empty((self.rollout_batch_size, self.dims['o']), np.float32) # observations - self.initial_ag = np.empty((self.rollout_batch_size, self.dims['g']), np.float32) # achieved goals - self.reset_all_rollouts() - self.clear_history() - - def reset_rollout(self, index): - """ - Resets the `i`-th rollout environment, re-samples a new goal, and updates the `initial_o` and `g` arrays - accordingly. - - :param index: (int) the index to reset - """ - obs = self.envs[index].reset() - self.initial_obs[index] = obs['observation'] - self.initial_ag[index] = obs['achieved_goal'] - self.goals[index] = obs['desired_goal'] - - def reset_all_rollouts(self): - """ - Resets all `rollout_batch_size` rollout workers. - """ - for step in range(self.rollout_batch_size): - self.reset_rollout(step) - - def generate_rollouts(self): - """ - Performs `rollout_batch_size` rollouts in parallel for time horizon with the current - policy acting on it accordingly. - - :return: (dict) batch - """ - self.reset_all_rollouts() - - # compute observations - observations = np.empty((self.rollout_batch_size, self.dims['o']), np.float32) # observations - achieved_goals = np.empty((self.rollout_batch_size, self.dims['g']), np.float32) # achieved goals - observations[:] = self.initial_obs - achieved_goals[:] = self.initial_ag - - # generate episodes - obs, achieved_goals, acts, goals, successes = [], [], [], [], [] - info_values = [np.empty((self.time_horizon, self.rollout_batch_size, self.dims['info_' + key]), np.float32) - for key in self.info_keys] - q_values = [] - for step in range(self.time_horizon): - policy_output = self.policy.get_actions( - observations, achieved_goals, self.goals, - compute_q=self.compute_q, - noise_eps=self.noise_eps if not self.exploit else 0., - random_eps=self.random_eps if not self.exploit else 0., - use_target_net=self.use_target_net) - - if self.compute_q: - action, q_value = policy_output - q_values.append(q_value) - else: - action = policy_output - - if action.ndim == 1: - # The non-batched case should still have a reasonable shape. - action = action.reshape(1, -1) - - o_new = np.empty((self.rollout_batch_size, self.dims['o'])) - ag_new = np.empty((self.rollout_batch_size, self.dims['g'])) - success = np.zeros(self.rollout_batch_size) - # compute new states and observations - for batch_idx in range(self.rollout_batch_size): - try: - # We fully ignore the reward here because it will have to be re-computed - # for HER. - curr_o_new, _, _, info = self.envs[batch_idx].step(action[batch_idx]) - if 'is_success' in info: - success[batch_idx] = info['is_success'] - o_new[batch_idx] = curr_o_new['observation'] - ag_new[batch_idx] = curr_o_new['achieved_goal'] - for idx, key in enumerate(self.info_keys): - info_values[idx][step, batch_idx] = info[key] - if self.render: - self.envs[batch_idx].render() - except MujocoException: - return self.generate_rollouts() - - if np.isnan(o_new).any(): - self.logger.warning('NaN caught during rollout generation. Trying again...') - self.reset_all_rollouts() - return self.generate_rollouts() - - obs.append(observations.copy()) - achieved_goals.append(achieved_goals.copy()) - successes.append(success.copy()) - acts.append(action.copy()) - goals.append(self.goals.copy()) - observations[...] = o_new - achieved_goals[...] = ag_new - obs.append(observations.copy()) - achieved_goals.append(achieved_goals.copy()) - self.initial_obs[:] = observations - - episode = dict(o=obs, - u=acts, - g=goals, - ag=achieved_goals) - for key, value in zip(self.info_keys, info_values): - episode['info_{}'.format(key)] = value - - # stats - successful = np.array(successes)[-1, :] - assert successful.shape == (self.rollout_batch_size,) - success_rate = np.mean(successful) - self.success_history.append(success_rate) - - if self.compute_q: - self.q_history.append(np.mean(q_values)) - self.n_episodes += self.rollout_batch_size - - return convert_episode_to_batch_major(episode) - - def clear_history(self): - """ - Clears all histories that are used for statistics - """ - self.success_history.clear() - self.q_history.clear() - - def current_success_rate(self): - """ - returns the current success rate - :return: (float) the success rate - """ - return np.mean(self.success_history) - - def current_mean_q(self): - """ - returns the current mean Q value - :return: (float) the mean Q value - """ - return np.mean(self.q_history) - - def save_policy(self, path): - """ - Pickles the current policy for later inspection. - - :param path: (str) the save location - """ - with open(path, 'wb') as file_handler: - pickle.dump(self.policy, file_handler) - - def logs(self, prefix='worker'): - """ - Generates a dictionary that contains all collected statistics. - - :param prefix: (str) the prefix for the name in logging - :return: ([(str, float)]) the logging information - """ - logs = [] - logs += [('success_rate', np.mean(self.success_history))] - if self.compute_q: - logs += [('mean_q', np.mean(self.q_history))] - logs += [('episode', self.n_episodes)] - - if prefix is not '' and not prefix.endswith('/'): - return [(prefix + '/' + key, val) for key, val in logs] - else: - return logs - - def seed(self, seed): - """ - Seeds each environment with a distinct seed derived from the passed in global seed. - - :param seed: (int) the random seed - """ - for idx, env in enumerate(self.envs): - env.seed(seed + 1000 * idx) diff --git a/stable_baselines/her/util.py b/stable_baselines/her/util.py deleted file mode 100644 index c5a7088981..0000000000 --- a/stable_baselines/her/util.py +++ /dev/null @@ -1,150 +0,0 @@ -import os -import subprocess -import sys -import importlib - -import tensorflow as tf -import numpy as np -from mpi4py import MPI - -from stable_baselines.common import tf_util - - -def import_function(spec): - """ - Import a function identified by a string like "pkg.module:fn_name". - - :param spec: (str) the function to import - :return: (function) - """ - mod_name, fn_name = spec.split(':') - module = importlib.import_module(mod_name) - func = getattr(module, fn_name) - return func - - -def flatten_grads(var_list, grads): - """ - Flattens a variables and their gradients. - - :param var_list: ([TensorFlow Tensor]) the variables - :param grads: ([TensorFlow Tensor]) the gradients - :return: (TensorFlow Tensor) the flattend variable and gradient - """ - return tf.concat([tf.reshape(grad, [tf_util.numel(v)]) - for (v, grad) in zip(var_list, grads)], 0) - - -def mlp(_input, layers_sizes, reuse=None, flatten=False, name=""): - """ - Creates a simple fully-connected neural network - - :param _input: (TensorFlow Tensor) the input - :param layers_sizes: ([int]) the hidden layers - :param reuse: (bool) Enable reuse of the network - :param flatten: (bool) flatten the network output - :param name: (str) the name of the network - :return: (TensorFlow Tensor) the network - """ - for i, size in enumerate(layers_sizes): - activation = tf.nn.relu if i < len(layers_sizes) - 1 else None - _input = tf.layers.dense(inputs=_input, - units=size, - kernel_initializer=tf.contrib.layers.xavier_initializer(), - reuse=reuse, - name=name + '_' + str(i)) - if activation: - _input = activation(_input) - if flatten: - assert layers_sizes[-1] == 1 - _input = tf.reshape(_input, [-1]) - return _input - - -def install_mpi_excepthook(): - """ - setup the MPI exception hooks - """ - old_hook = sys.excepthook - - def new_hook(a, b, c): - old_hook(a, b, c) - sys.stdout.flush() - sys.stderr.flush() - MPI.COMM_WORLD.Abort() - - sys.excepthook = new_hook - - -def mpi_fork(rank, extra_mpi_args=None): - """ - Re-launches the current script with workers - Returns "parent" for original parent, "child" for MPI children - - :param rank: (int) the thread rank - :param extra_mpi_args: (dict) extra arguments for MPI - :return: (str) the correct type of thread name - """ - if extra_mpi_args is None: - extra_mpi_args = [] - - if rank <= 1: - return "child" - if os.getenv("IN_MPI") is None: - env = os.environ.copy() - env.update( - MKL_NUM_THREADS="1", - OMP_NUM_THREADS="1", - IN_MPI="1" - ) - # "-bind-to core" is crucial for good performance - args = ["mpirun", "-np", str(rank)] + \ - extra_mpi_args + \ - [sys.executable] - - args += sys.argv - subprocess.check_call(args, env=env) - return "parent" - else: - install_mpi_excepthook() - return "child" - - -def convert_episode_to_batch_major(episode): - """ - Converts an episode to have the batch dimension in the major (first) dimension. - - :param episode: (dict) the episode batch - :return: (dict) the episode batch with he batch dimension in the major (first) dimension. - """ - episode_batch = {} - for key in episode.keys(): - val = np.array(episode[key]).copy() - # make inputs batch-major instead of time-major - episode_batch[key] = val.swapaxes(0, 1) - - return episode_batch - - -def transitions_in_episode_batch(episode_batch): - """ - Number of transitions in a given episode batch. - - :param episode_batch: (dict) the episode batch - :return: (int) the number of transitions in episode batch - """ - shape = episode_batch['u'].shape - return shape[0] * shape[1] - - -def reshape_for_broadcasting(source, target): - """ - Reshapes a tensor (source) to have the correct shape and dtype of the target before broadcasting it with MPI. - - :param source: (TensorFlow Tensor) the input tensor - :param target: (TensorFlow Tensor) the target tensor - :return: (TensorFlow Tensor) the rehshaped tensor - """ - dim = len(target.get_shape()) - shape = ([1] * (dim - 1)) + [-1] - return tf.reshape(tf.cast(source, target.dtype), shape) diff --git a/stable_baselines/her/utils.py b/stable_baselines/her/utils.py new file mode 100644 index 0000000000..5ef32e8545 --- /dev/null +++ b/stable_baselines/her/utils.py @@ -0,0 +1,54 @@ +import numpy as np +from gym import spaces + +class HERGoalEnvWrapper(object): + """docstring for HERGoalEnvWrapper.""" + + def __init__(self, env): + super(HERGoalEnvWrapper, self).__init__() + self.env = env + self.action_space = env.action_space + self.spaces = env.observation_space.spaces.values() + # TODO: check that all spaces are of the same type + # (current limiation of the wrapper) + # TODO: check when dim > 1 + self.obs_dim = env.observation_space.spaces['observation'].shape[0] + self.goal_dim = env.observation_space.spaces['achieved_goal'].shape[0] + total_dim = self.obs_dim + 2 * self.goal_dim + + if isinstance(self.spaces[0], spaces.MultiBinary): + self.observation_space = spaces.MultiBinary(total_dim) + elif isinstance(self.spaces[0], spaces.Box): + # total_dim = np.sum([space.shape[0] for space in self.spaces]) + # self.observation_space = spaces.Box(-np.inf, np.inf, shape=(total_dim, ), dtype=np.float32) + raise NotImplementedError() + else: + raise NotImplementedError() + + @staticmethod + def convert_dict_to_obs(obs_dict): + # Note: we should remove achieved goal from the observation ? + return np.concatenate([obs for obs in obs_dict.values()]) + + def convert_obs_to_dict(self, observations): + return { + 'observation': observations[:self.obs_dim], + 'achieved_goal': observations[self.obs_dim:self.obs_dim + self.goal_dim], + 'desired_goal': observations[self.obs_dim + self.goal_dim:], + } + + def step(self, action): + obs, reward, done, info = self.env.step(action) + return self.convert_dict_to_obs(obs), reward, done, info + + def seed(self, seed=None): + return self.env.seed(seed) + + def reset(self): + return self.convert_dict_to_obs(self.env.reset()) + + def render(self, mode='human'): + return self.env.render(mode) + + def close(self): + self.env.close() From 7ff52088d5008a26f199d7d3bb1ac635ae4b465f Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 15 Apr 2019 13:31:10 +0200 Subject: [PATCH 03/36] DQN + HER --- stable_baselines/__init__.py | 1 + stable_baselines/common/bit_flipping_env.py | 7 +- stable_baselines/deepq/dqn.py | 8 ++- stable_baselines/her/__init__.py | 1 + stable_baselines/her/her.py | 30 +++++---- stable_baselines/her/replay_buffer.py | 72 ++++++++++----------- stable_baselines/her/utils.py | 33 ++++++---- tests/test_her.py | 15 +++++ 8 files changed, 98 insertions(+), 69 deletions(-) create mode 100644 tests/test_her.py diff --git a/stable_baselines/__init__.py b/stable_baselines/__init__.py index 5963f79b85..9c96fc13e9 100644 --- a/stable_baselines/__init__.py +++ b/stable_baselines/__init__.py @@ -3,6 +3,7 @@ from stable_baselines.acktr import ACKTR from stable_baselines.ddpg import DDPG from stable_baselines.deepq import DQN +from stable_baselines.her import HER from stable_baselines.gail import GAIL from stable_baselines.ppo1 import PPO1 from stable_baselines.ppo2 import PPO2 diff --git a/stable_baselines/common/bit_flipping_env.py b/stable_baselines/common/bit_flipping_env.py index 33ec9b947d..571ed9d861 100644 --- a/stable_baselines/common/bit_flipping_env.py +++ b/stable_baselines/common/bit_flipping_env.py @@ -61,15 +61,14 @@ def step(self, action): done = done or self.current_step >= self.max_steps return obs, reward, done, {} - @staticmethod - def compute_reward(achieved_goal, desired_goal, _info): + def compute_reward(self, achieved_goal, desired_goal, _info): # Deceptive reward: it is positive only when the goal is achieved return 0 if (achieved_goal == desired_goal).all() else -1 - def render(mode='human'): + def render(self, mode='human'): if mode == 'rgb_array': return self.state.copy() print(self.state) - def close(): + def close(self): pass diff --git a/stable_baselines/deepq/dqn.py b/stable_baselines/deepq/dqn.py index d56dd34a8b..86d7ba4cdc 100644 --- a/stable_baselines/deepq/dqn.py +++ b/stable_baselines/deepq/dqn.py @@ -143,7 +143,7 @@ def setup_model(self): self.summary = tf.summary.merge_all() def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="DQN", - reset_num_timesteps=True): + reset_num_timesteps=True, replay_wrapper=None): new_tb_log = self._init_num_timesteps(reset_num_timesteps) @@ -164,6 +164,12 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ else: self.replay_buffer = ReplayBuffer(self.buffer_size) self.beta_schedule = None + + if replay_wrapper is not None: + assert not self.prioritized_replay, "Prioritized replay buffer is not supported by HER" + self.replay_buffer = replay_wrapper(self.replay_buffer) + + # Create the schedule for exploration starting from 1. self.exploration = LinearSchedule(schedule_timesteps=int(self.exploration_fraction * total_timesteps), initial_p=1.0, diff --git a/stable_baselines/her/__init__.py b/stable_baselines/her/__init__.py index 4c28812c8a..0c7dd57c64 100644 --- a/stable_baselines/her/__init__.py +++ b/stable_baselines/her/__init__.py @@ -1 +1,2 @@ from stable_baselines.her.her import HER +from stable_baselines.her.replay_buffer import GoalSelectionStrategy diff --git a/stable_baselines/her/her.py b/stable_baselines/her/her.py index 2a0f063191..89358b4d20 100644 --- a/stable_baselines/her/her.py +++ b/stable_baselines/her/her.py @@ -1,8 +1,9 @@ -import tensorflow as tf -import numpy as np +import functools + import gym -from stable_baselines.common import BaseRLModel, SetVerbosity +from stable_baselines.common import BaseRLModel +from .replay_buffer import HindsightExperienceReplayWrapper, KEY_TO_GOAL_STRATEGY from .utils import HERGoalEnvWrapper @@ -10,32 +11,33 @@ class HER(BaseRLModel): """ Hindsight Experience replay. """ - def __init__(self, policy, env, model_class, sampling_strategy, get_achieved_goal, - verbose=0, _init_setup_model=True): + + def __init__(self, policy, env, model_class, n_sampled_goal=4, + goal_selection_strategy='future', *args, **kwargs): # super().__init__(policy=policy, env=env, verbose=verbose, policy_base=None, requires_vec_env=False) self.model_class = model_class self.env = env assert isinstance(self.env, gym.GoalEnv), "HER only supports gym.GoalEnv" self.wrapped_env = HERGoalEnvWrapper(env) - - self.model = self.model_class(policy, self.wrapped_env) - - + if isinstance(goal_selection_strategy, str): + assert goal_selection_strategy in KEY_TO_GOAL_STRATEGY.keys() + goal_selection_strategy = KEY_TO_GOAL_STRATEGY[goal_selection_strategy] + self.replay_wrapper = functools.partial(HindsightExperienceReplayWrapper, n_sampled_goal=n_sampled_goal, + goal_selection_strategy=goal_selection_strategy, + wrapped_env=self.wrapped_env) + self.model = self.model_class(policy, self.wrapped_env, *args, **kwargs) def _get_pretrain_placeholders(self): raise NotImplementedError() - def setup_model(self): pass def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="HER", reset_num_timesteps=True): - with SetVerbosity(self.verbose): - self._setup_learn(seed) - - return self + return self.model.learn(total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="HER", + reset_num_timesteps=True, replay_wrapper=self.replay_wrapper) def predict(self, observation, state=None, mask=None, deterministic=False): pass diff --git a/stable_baselines/her/replay_buffer.py b/stable_baselines/her/replay_buffer.py index c5d865523a..f6e405e011 100644 --- a/stable_baselines/her/replay_buffer.py +++ b/stable_baselines/her/replay_buffer.py @@ -1,10 +1,8 @@ -from enum import Enum import copy +from enum import Enum import numpy as np -from stable_baselines.deepq.replay_buffer import ReplayBuffer - class GoalSelectionStrategy(Enum): FUTURE = 0 @@ -13,54 +11,54 @@ class GoalSelectionStrategy(Enum): RANDOM = 3 -class HindsightExperienceReplayBuffer(ReplayBuffer): - def __init__(self, size, n_sampled_goal, goal_selection_strategy, env): - """ +KEY_TO_GOAL_STRATEGY = { + 'future': GoalSelectionStrategy.FUTURE, + 'final': GoalSelectionStrategy.FINAL, + 'episode': GoalSelectionStrategy.EPISODE, + 'random': GoalSelectionStrategy.RANDOM +} + +class HindsightExperienceReplayWrapper(object): + def __init__(self, replay_buffer, n_sampled_goal, goal_selection_strategy, wrapped_env): + """ Inspired by https://github.com/NervanaSystems/coach/. - :param size: (int) Max number of transitions to store in the buffer. When the buffer overflows the old - memories are dropped. :param n_sampled_goal: The number of artificial transitions to generate for each actual transition :param goal_selection_strategy: The method that will be used for generating the goals for the hindsight transitions. Should be one of GoalSelectionStrategy - :param env: + :param wrapped_env: """ - super(HER, self).__init__(size) + super(HindsightExperienceReplayWrapper, self).__init__() self.n_sampled_goal = n_sampled_goal + assert isinstance(goal_selection_strategy, GoalSelectionStrategy) self.goal_selection_strategy = goal_selection_strategy - self.env = env + self.env = wrapped_env self.current_episode = [] - self.get_achieved_goal = None + self.replay_buffer = replay_buffer def add(self, obs_t, action, reward, obs_tp1, done): """ add a new transition to the buffer - :param obs_t: (Any) the last observation + :param obs_t: (np.ndarray) the last observation :param action: ([float]) the action :param reward: (float) the reward of the transition - :param obs_tp1: (Any) the current observation + :param obs_tp1: (np.ndarray) the new observation :param done: (bool) is the episode done """ + assert self.replay_buffer is not None self.current_episode.append((obs_t, action, reward, obs_tp1, done)) if done: # Add transitions (and imagined ones) to buffer only when an episode is over self._store_episode() self.current_episode = [] - # def _encode_sample(self, idxes): - # obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], [] - # for i in idxes: - # data = self._storage[i] - # obs_t, action, reward, obs_tp1, done = data - # obses_t.append(np.array(obs_t, copy=False)) - # actions.append(np.array(action, copy=False)) - # rewards.append(reward) - # obses_tp1.append(np.array(obs_tp1, copy=False)) - # dones.append(done) - # return np.array(obses_t), np.array(actions), np.array(rewards), np.array(obses_tp1), np.array(dones) - # + def sample(self, *args, **kwargs): + return self.replay_buffer.sample(*args, **kwargs) + + def __len__(self): + return len(self.replay_buffer) def _sample_goal(self, episode_transitions, transition_idx): """ @@ -72,21 +70,23 @@ def _sample_goal(self, episode_transitions, transition_idx): """ if self.goal_selection_strategy == GoalSelectionStrategy.FUTURE: # Sample a goal that was observed in the same episode after the current step - selected_obs = np.random.choice(episode_transitions[transition_idx + 1:]) + selected_idx = np.random.choice(np.arange(transition_idx + 1, len(episode_transitions))) + selected_transition = episode_transitions[selected_idx] elif self.goal_selection_strategy == GoalSelectionStrategy.FINAL: # The achieved goal at the end of the episode - selected_obs = episode_transitions[-1] + selected_transition = episode_transitions[-1] elif self.goal_selection_strategy == GoalSelectionStrategy.EPISODE: # Random achieved goal during the episode - selected_obs = np.random.choice(episode_transitions) + selected_idx = np.random.choice(np.arange(len(episode_transitions))) + selected_transition = episode_transitions[selected_idx] elif self.goal_selection_strategy == GoalSelectionStrategy.RANDOM: # Random achieved goal from the entire replay buffer - # selected_obs = np.random.choice(self._storage) - raise NotImplementedError() + selected_idx = np.random.choice(np.arange(len(self.replay_buffer))) + selected_transition = self.replay_buffer._storage[selected_idx] else: raise ValueError("Invalid goal selection strategy," "please use one of {}".format(list(GoalSelectionStrategy))) - return self.get_achieved_goal(selected_obs) + return self.env.convert_obs_to_dict(selected_transition[0])['achieved_goal'] def _sample_goals(self, episode_transitions, transition_idx): """ @@ -109,10 +109,10 @@ def _store_episode(self): obs_t, action, reward, obs_tp1, done = transition # Add to the replay buffer - super().add(obs_t, action, reward, obs_tp1, done) + self.replay_buffer.add(obs_t, action, reward, obs_tp1, done) # We cannot sample a goal from the future in the last step of an episode if (transition_idx == len(last_episode_transitions) - 1 and - self.goal_selection_strategy == GoalSelectionStrategy.FUTURE): + self.goal_selection_strategy == GoalSelectionStrategy.FUTURE): break # Sampled n goals per transition, where n is `n_sampled_goal` @@ -131,11 +131,11 @@ def _store_episode(self): next_obs_dict['desired_goal'] = goal # update the reward and terminal signal according to the goal - reward = self.env.compute_reward(goal, next_obs_dict['achieved_goal']) + reward = self.env.compute_reward(goal, next_obs_dict['achieved_goal'], None) # Can we ensure that done = reward == 0 done = False obs, next_obs = map(self.env.convert_dict_to_obs, (obs_dict, next_obs_dict)) # Add to the replay buffer - super().add(obs, action, reward, next_obs, done) + self.replay_buffer.add(obs, action, reward, next_obs, done) diff --git a/stable_baselines/her/utils.py b/stable_baselines/her/utils.py index 5ef32e8545..ee24e10cd9 100644 --- a/stable_baselines/her/utils.py +++ b/stable_baselines/her/utils.py @@ -1,14 +1,16 @@ import numpy as np from gym import spaces + class HERGoalEnvWrapper(object): """docstring for HERGoalEnvWrapper.""" def __init__(self, env): super(HERGoalEnvWrapper, self).__init__() self.env = env + self.metadata = self.env.metadata self.action_space = env.action_space - self.spaces = env.observation_space.spaces.values() + self.spaces = list(env.observation_space.spaces.values()) # TODO: check that all spaces are of the same type # (current limiation of the wrapper) # TODO: check when dim > 1 @@ -26,9 +28,9 @@ def __init__(self, env): raise NotImplementedError() @staticmethod - def convert_dict_to_obs(obs_dict): + def convert_dict_to_obs(obs_dict): # Note: we should remove achieved goal from the observation ? - return np.concatenate([obs for obs in obs_dict.values()]) + return np.concatenate([obs for obs in obs_dict.values()]) def convert_obs_to_dict(self, observations): return { @@ -37,18 +39,21 @@ def convert_obs_to_dict(self, observations): 'desired_goal': observations[self.obs_dim + self.goal_dim:], } - def step(self, action): - obs, reward, done, info = self.env.step(action) - return self.convert_dict_to_obs(obs), reward, done, info + def step(self, action): + obs, reward, done, info = self.env.step(action) + return self.convert_dict_to_obs(obs), reward, done, info + + def seed(self, seed=None): + return self.env.seed(seed) - def seed(self, seed=None): - return self.env.seed(seed) + def reset(self): + return self.convert_dict_to_obs(self.env.reset()) - def reset(self): - return self.convert_dict_to_obs(self.env.reset()) + def compute_reward(self, achieved_goal, desired_goal, info): + return self.env.compute_reward(achieved_goal, desired_goal, info) - def render(self, mode='human'): - return self.env.render(mode) + def render(self, mode='human'): + return self.env.render(mode) - def close(self): - self.env.close() + def close(self): + return self.env.close() diff --git a/tests/test_her.py b/tests/test_her.py new file mode 100644 index 0000000000..cef487c1e4 --- /dev/null +++ b/tests/test_her.py @@ -0,0 +1,15 @@ +import pytest + +from stable_baselines import HER, DQN, SAC, DDPG +from stable_baselines.her import GoalSelectionStrategy +from stable_baselines.common.bit_flipping_env import BitFlippingEnv + +N_BITS = 10 + + +@pytest.mark.parametrize('goal_selection_strategy', list(GoalSelectionStrategy)) +def test_dqn_her(goal_selection_strategy): + env = BitFlippingEnv(N_BITS, continuous=False, max_steps=N_BITS) + model = HER('MlpPolicy', env, DQN, n_sampled_goal=4, goal_selection_strategy=goal_selection_strategy, + prioritized_replay=False, verbose=1) + model.learn(5000) From 3e673302ec81838789a0679283effa02c32c55e1 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 16 Apr 2019 22:38:48 +0200 Subject: [PATCH 04/36] Add support for SAC and DDPG --- docs/misc/changelog.rst | 2 + stable_baselines/ddpg/ddpg.py | 111 ++++++++++++++-------- stable_baselines/ddpg/main.py | 1 - stable_baselines/ddpg/memory.py | 130 -------------------------- stable_baselines/her/her.py | 14 ++- stable_baselines/her/replay_buffer.py | 49 ++++++---- stable_baselines/her/utils.py | 15 ++- stable_baselines/sac/sac.py | 5 +- 8 files changed, 133 insertions(+), 194 deletions(-) delete mode 100644 stable_baselines/ddpg/memory.py diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 6206dc95d3..7f87d58e53 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -10,6 +10,8 @@ Pre-Release 2.5.1a0 (WIP) - doc update (fix example of result plotter + improve doc) - fixed logger issues when stdout lacks ``read`` function +- **deprecated** ``memory_limit`` and ``memory_policy`` in DDPG, please use ``buffer_size`` instead. (will be removed in v3.x.x) +- removed ``stable_baselines.ddpg.memory`` in favor of ``stable_baselines.deepq.replay_buffer`` Release 2.5.0 (2019-03-28) diff --git a/stable_baselines/ddpg/ddpg.py b/stable_baselines/ddpg/ddpg.py index 82e3a1b48a..afc3352f6c 100644 --- a/stable_baselines/ddpg/ddpg.py +++ b/stable_baselines/ddpg/ddpg.py @@ -18,7 +18,7 @@ from stable_baselines.ddpg.policies import DDPGPolicy from stable_baselines.common.mpi_running_mean_std import RunningMeanStd from stable_baselines.a2c.utils import find_trainable_variables, total_episode_reward_logger -from stable_baselines.ddpg.memory import Memory +from stable_baselines.deepq.replay_buffer import ReplayBuffer def normalize(tensor, stats): @@ -114,7 +114,7 @@ def get_perturbed_actor_updates(actor, perturbed_actor, param_noise_stddev, verb assert len(tf_util.get_globals_vars(actor)) == len(tf_util.get_globals_vars(perturbed_actor)) assert len([var for var in tf_util.get_trainable_vars(actor) if 'LayerNorm' not in var.name]) == \ - len([var for var in tf_util.get_trainable_vars(perturbed_actor) if 'LayerNorm' not in var.name]) + len([var for var in tf_util.get_trainable_vars(perturbed_actor) if 'LayerNorm' not in var.name]) updates = [] for var, perturbed_var in zip(tf_util.get_globals_vars(actor), tf_util.get_globals_vars(perturbed_actor)): @@ -140,7 +140,10 @@ class DDPG(OffPolicyRLModel): :param policy: (DDPGPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, LnMlpPolicy, ...) :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str) :param gamma: (float) the discount factor - :param memory_policy: (Memory) the replay buffer (if None, default to baselines.ddpg.memory.Memory) + :param memory_policy: (ReplayBuffer) the replay buffer + (if None, default to baselines.deepq.replay_buffer.ReplayBuffer) + .. deprecated:: 2.6.0 + This parameter will be removed in a future version :param eval_env: (Gym Environment) the evaluation environment (can be None) :param nb_train_steps: (int) the number of training steps :param nb_rollout_steps: (int) the number of rollout steps @@ -163,7 +166,10 @@ class DDPG(OffPolicyRLModel): :param reward_scale: (float) the value the reward should be scaled by :param render: (bool) enable rendering of the environment :param render_eval: (bool) enable rendering of the evalution environment - :param memory_limit: (int) the max number of transitions to store + :param memory_limit: (int) the max number of transitions to store, size of the replay buffer + .. deprecated:: 2.6.0 + Use `buffer_size` instead. + :param buffer_size: (int) the max number of transitions to store, size of the replay buffer :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance @@ -177,17 +183,28 @@ def __init__(self, policy, env, gamma=0.99, memory_policy=None, eval_env=None, n normalize_observations=False, tau=0.001, batch_size=128, param_noise_adaption_interval=50, normalize_returns=False, enable_popart=False, observation_range=(-5., 5.), critic_l2_reg=0., return_range=(-np.inf, np.inf), actor_lr=1e-4, critic_lr=1e-3, clip_norm=None, reward_scale=1., - render=False, render_eval=False, memory_limit=50000, verbose=0, tensorboard_log=None, - _init_setup_model=True, policy_kwargs=None, full_tensorboard_log=False): + render=False, render_eval=False, memory_limit=None, buffer_size=50000, + verbose=0, tensorboard_log=None, _init_setup_model=True, policy_kwargs=None, + full_tensorboard_log=False): - # TODO: replay_buffer refactoring - super(DDPG, self).__init__(policy=policy, env=env, replay_buffer=None, verbose=verbose, policy_base=DDPGPolicy, + super(DDPG, self).__init__(policy=policy, env=env, replay_buffer=None, + verbose=verbose, policy_base=DDPGPolicy, requires_vec_env=False, policy_kwargs=policy_kwargs) # Parameters. self.gamma = gamma self.tau = tau - self.memory_policy = memory_policy or Memory + + # TODO: remove this param in v3.x.x + if memory_policy is not None: + warnings.warn("memory_policy will be removed in a future version (v3.x.x) " + "it is now ignored and replaced with ReplayBuffer", DeprecationWarning) + + if memory_limit is not None: + warnings.warn("memory_limit will be removed in a future version (v3.x.x) " + "use buffer_size instead", DeprecationWarning) + buffer_size = memory_limit + self.normalize_observations = normalize_observations self.normalize_returns = normalize_returns self.action_noise = action_noise @@ -209,13 +226,14 @@ def __init__(self, policy, env, gamma=0.99, memory_policy=None, eval_env=None, n self.nb_train_steps = nb_train_steps self.nb_rollout_steps = nb_rollout_steps self.memory_limit = memory_limit + self.buffer_size = buffer_size self.tensorboard_log = tensorboard_log self.full_tensorboard_log = full_tensorboard_log # init self.graph = None self.stats_sample = None - self.memory = None + self.replay_buffer = None self.policy_tf = None self.target_init_updates = None self.target_soft_updates = None @@ -265,6 +283,8 @@ def __init__(self, policy, env, gamma=0.99, memory_policy=None, eval_env=None, n self.tb_seen_steps = None self.target_params = None + self.obs_rms_params = None + self.ret_rms_params = None if _init_setup_model: self.setup_model() @@ -287,8 +307,7 @@ def setup_model(self): with self.graph.as_default(): self.sess = tf_util.single_threaded_session(graph=self.graph) - self.memory = self.memory_policy(limit=self.memory_limit, action_shape=self.action_space.shape, - observation_shape=self.observation_space.shape) + self.replay_buffer = ReplayBuffer(self.buffer_size) with tf.variable_scope("input", reuse=False): # Observation normalization. @@ -401,9 +420,9 @@ def setup_model(self): self.params = find_trainable_variables("model") self.target_params = find_trainable_variables("target") self.obs_rms_params = [var for var in tf.global_variables() - if "obs_rms" in var.name] + if "obs_rms" in var.name] self.ret_rms_params = [var for var in tf.global_variables() - if "ret_rms" in var.name] + if "ret_rms" in var.name] with self.sess.as_default(): self._initialize(self.sess) @@ -596,10 +615,10 @@ def _store_transition(self, obs0, action, reward, obs1, terminal1): :param action: ([float]) the action :param reward: (float] the reward :param obs1: ([float] or [int]) the current observation - :param terminal1: (bool) is the episode done + :param terminal1: (bool) Whether the episode is over """ reward *= self.reward_scale - self.memory.append(obs0, action, reward, obs1, terminal1) + self.replay_buffer.add(obs0, action, reward, obs1, float(terminal1)) if self.normalize_observations: self.obs_rms.update(np.array([obs0])) @@ -613,14 +632,17 @@ def _train_step(self, step, writer, log=False): :return: (float, float) critic loss, actor loss """ # Get a batch - batch = self.memory.sample(batch_size=self.batch_size) + obs0, actions, rewards, obs1, terminals1 = self.replay_buffer.sample(batch_size=self.batch_size) + # Reshape to match previous behavior and placeholder shape + rewards = rewards.reshape(-1, 1) + terminals1 = terminals1.reshape(-1, 1) if self.normalize_returns and self.enable_popart: old_mean, old_std, target_q = self.sess.run([self.ret_rms.mean, self.ret_rms.std, self.target_q], feed_dict={ - self.obs_target: batch['obs1'], - self.rewards: batch['rewards'], - self.terminals1: batch['terminals1'].astype('float32') + self.obs_target: obs1, + self.rewards: rewards, + self.terminals1: terminals1 }) self.ret_rms.update(target_q.flatten()) self.sess.run(self.renormalize_q_outputs_op, feed_dict={ @@ -630,18 +652,18 @@ def _train_step(self, step, writer, log=False): else: target_q = self.sess.run(self.target_q, feed_dict={ - self.obs_target: batch['obs1'], - self.rewards: batch['rewards'], - self.terminals1: batch['terminals1'].astype('float32') + self.obs_target: obs1, + self.rewards: rewards, + self.terminals1: terminals1 }) # Get all gradients and perform a synced update. ops = [self.actor_grads, self.actor_loss, self.critic_grads, self.critic_loss] td_map = { - self.obs_train: batch['obs0'], - self.actions: batch['actions'], - self.action_train_ph: batch['actions'], - self.rewards: batch['rewards'], + self.obs_train: obs0, + self.actions: actions, + self.action_train_ph: actions, + self.rewards: rewards, self.critic_target: target_q, self.param_noise_stddev: 0 if self.param_noise is None else self.param_noise.current_stddev } @@ -695,7 +717,14 @@ def _get_stats(self): if self.stats_sample is None: # Get a sample and keep that fixed for all further computations. # This allows us to estimate the change in value for the same set of inputs. - self.stats_sample = self.memory.sample(batch_size=self.batch_size) + obs0, actions, rewards, obs1, terminals1 = self.replay_buffer.sample(batch_size=self.batch_size) + self.stats_sample = { + 'obs0': obs0, + 'actions': actions, + 'rewards': rewards, + 'obs1': obs1, + 'terminals1': terminals1 + } feed_dict = { self.actions: self.stats_sample['actions'] @@ -730,12 +759,12 @@ def _adapt_param_noise(self): return 0. # Perturb a separate copy of the policy to adjust the scale for the next "real" perturbation. - batch = self.memory.sample(batch_size=self.batch_size) + obs0, *_ = self.replay_buffer.sample(batch_size=self.batch_size) self.sess.run(self.perturb_adaptive_policy_ops, feed_dict={ self.param_noise_stddev: self.param_noise.current_stddev, }) distance = self.sess.run(self.adaptive_policy_distance, feed_dict={ - self.obs_adapt_noise: batch['obs0'], self.obs_train: batch['obs0'], + self.obs_adapt_noise: obs0, self.obs_train: obs0, self.param_noise_stddev: self.param_noise.current_stddev, }) @@ -755,10 +784,13 @@ def _reset(self): }) def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="DDPG", - reset_num_timesteps=True): + reset_num_timesteps=True, replay_wrapper=None): new_tb_log = self._init_num_timesteps(reset_num_timesteps) + if replay_wrapper is not None: + self.replay_buffer = replay_wrapper(self.replay_buffer) + with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \ as writer: self._setup_learn(seed) @@ -862,7 +894,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ epoch_adaptive_distances = [] for t_train in range(self.nb_train_steps): # Adapt param noise, if necessary. - if self.memory.nb_entries >= self.batch_size and \ + if len(self.replay_buffer) >= self.batch_size and \ t_train % self.param_noise_adaption_interval == 0: distance = self._adapt_param_noise() epoch_adaptive_distances.append(distance) @@ -1013,8 +1045,8 @@ def save(self, save_path): "clip_norm": self.clip_norm, "reward_scale": self.reward_scale, "memory_limit": self.memory_limit, + "buffer_size": self.buffer_size, "policy": self.policy, - "memory_policy": self.memory_policy, "n_envs": self.n_envs, "_vectorize_action": self._vectorize_action, "policy_kwargs": self.policy_kwargs @@ -1026,14 +1058,13 @@ def save(self, save_path): norm_ret_params = self.sess.run(self.ret_rms_params) params_to_save = params \ - + target_params \ - + norm_obs_params \ - + norm_ret_params + + target_params \ + + norm_obs_params \ + + norm_ret_params self._save_to_file(save_path, data=data, params=params_to_save) - @classmethod def load(cls, load_path, env=None, **kwargs): data, params = cls._load_from_file(load_path) @@ -1051,9 +1082,9 @@ def load(cls, load_path, env=None, **kwargs): restores = [] params_to_load = model.params \ - + model.target_params \ - + model.obs_rms_params \ - + model.ret_rms_params + + model.target_params \ + + model.obs_rms_params \ + + model.ret_rms_params for param, loaded_p in zip(params_to_load, params): restores.append(param.assign(loaded_p)) model.sess.run(restores) diff --git a/stable_baselines/ddpg/main.py b/stable_baselines/ddpg/main.py index 106efb03e1..466fd7a7f2 100644 --- a/stable_baselines/ddpg/main.py +++ b/stable_baselines/ddpg/main.py @@ -11,7 +11,6 @@ from stable_baselines.common.misc_util import set_global_seeds, boolean_flag from stable_baselines.ddpg.policies import MlpPolicy, LnMlpPolicy from stable_baselines.ddpg import DDPG -from stable_baselines.ddpg.memory import Memory from stable_baselines.ddpg.noise import AdaptiveParamNoiseSpec, OrnsteinUhlenbeckActionNoise, NormalActionNoise diff --git a/stable_baselines/ddpg/memory.py b/stable_baselines/ddpg/memory.py deleted file mode 100644 index ee5e94fdff..0000000000 --- a/stable_baselines/ddpg/memory.py +++ /dev/null @@ -1,130 +0,0 @@ -import numpy as np - - -class RingBuffer(object): - def __init__(self, maxlen, shape, dtype='float32'): - """ - A buffer object, when full restarts at the initial position - - :param maxlen: (int) the max number of numpy objects to store - :param shape: (tuple) the shape of the numpy objects you want to store - :param dtype: (str) the name of the type of the numpy object you want to store - """ - self.maxlen = maxlen - self.start = 0 - self.length = 0 - self.data = np.zeros((maxlen,) + shape).astype(dtype) - - def __len__(self): - return self.length - - def __getitem__(self, idx): - if idx < 0 or idx >= self.length: - raise KeyError() - return self.data[(self.start + idx) % self.maxlen] - - def get_batch(self, idxs): - """ - get the value at the indexes - - :param idxs: (int or numpy int) the indexes - :return: (np.ndarray) the stored information in the buffer at the asked positions - """ - return self.data[(self.start + idxs) % self.maxlen] - - def append(self, var): - """ - Append an object to the buffer - - :param var: (np.ndarray) the object you wish to add - """ - if self.length < self.maxlen: - # We have space, simply increase the length. - self.length += 1 - elif self.length == self.maxlen: - # No space, "remove" the first item. - self.start = (self.start + 1) % self.maxlen - else: - # This should never happen. - raise RuntimeError() - self.data[(self.start + self.length - 1) % self.maxlen] = var - - -def array_min2d(arr): - """ - cast to np.ndarray, and make sure it is of 2 dim - - :param arr: ([Any]) the array to clean - :return: (np.ndarray) the cleaned array - """ - arr = np.array(arr) - if arr.ndim >= 2: - return arr - return arr.reshape(-1, 1) - - -class Memory(object): - def __init__(self, limit, action_shape, observation_shape): - """ - The replay buffer object - - :param limit: (int) the max number of transitions to store - :param action_shape: (tuple) the action shape - :param observation_shape: (tuple) the observation shape - """ - self.limit = limit - - self.observations0 = RingBuffer(limit, shape=observation_shape) - self.actions = RingBuffer(limit, shape=action_shape) - self.rewards = RingBuffer(limit, shape=(1,)) - self.terminals1 = RingBuffer(limit, shape=(1,)) - self.observations1 = RingBuffer(limit, shape=observation_shape) - - def sample(self, batch_size): - """ - sample a random batch from the buffer - - :param batch_size: (int) the number of element to sample for the batch - :return: (dict) the sampled batch - """ - # Draw such that we always have a proceeding element. - batch_idxs = np.random.randint(low=1, high=self.nb_entries - 1, size=batch_size) - - obs0_batch = self.observations0.get_batch(batch_idxs) - obs1_batch = self.observations1.get_batch(batch_idxs) - action_batch = self.actions.get_batch(batch_idxs) - reward_batch = self.rewards.get_batch(batch_idxs) - terminal1_batch = self.terminals1.get_batch(batch_idxs) - - result = { - 'obs0': array_min2d(obs0_batch), - 'obs1': array_min2d(obs1_batch), - 'rewards': array_min2d(reward_batch), - 'actions': array_min2d(action_batch), - 'terminals1': array_min2d(terminal1_batch), - } - return result - - def append(self, obs0, action, reward, obs1, terminal1, training=True): - """ - Append a transition to the buffer - - :param obs0: ([float] or [int]) the last observation - :param action: ([float]) the action - :param reward: (float] the reward - :param obs1: ([float] or [int]) the current observation - :param terminal1: (bool) is the episode done - :param training: (bool) is the RL model training or not - """ - if not training: - return - - self.observations0.append(obs0) - self.actions.append(action) - self.rewards.append(reward) - self.observations1.append(obs1) - self.terminals1.append(terminal1) - - @property - def nb_entries(self): - return len(self.observations0) diff --git a/stable_baselines/her/her.py b/stable_baselines/her/her.py index 89358b4d20..d45c76f68b 100644 --- a/stable_baselines/her/her.py +++ b/stable_baselines/her/her.py @@ -9,7 +9,14 @@ class HER(BaseRLModel): """ - Hindsight Experience replay. + Hindsight Experience Replay (HER) https://arxiv.org/abs/1707.01495 + + :param policy: (BasePolicy or str) The policy model to use (MlpPolicy, CnnPolicy, CnnLstmPolicy, ...) + :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str) + :param model_class: (OffPolicyRLModel) The off policy RL model to apply Hindsight Experience Replay + currently supported: DQN, DDPG, SAC + :param n_sampled_goal: (int) + :param goal_selection_strategy: (GoalSelectionStrategy or str) """ def __init__(self, policy, env, model_class, n_sampled_goal=4, @@ -18,11 +25,14 @@ def __init__(self, policy, env, model_class, n_sampled_goal=4, self.model_class = model_class self.env = env - assert isinstance(self.env, gym.GoalEnv), "HER only supports gym.GoalEnv" + # TODO: check for TimeLimit wrapper too + # assert isinstance(self.env, gym.GoalEnv), "HER only supports gym.GoalEnv" self.wrapped_env = HERGoalEnvWrapper(env) + if isinstance(goal_selection_strategy, str): assert goal_selection_strategy in KEY_TO_GOAL_STRATEGY.keys() goal_selection_strategy = KEY_TO_GOAL_STRATEGY[goal_selection_strategy] + self.replay_wrapper = functools.partial(HindsightExperienceReplayWrapper, n_sampled_goal=n_sampled_goal, goal_selection_strategy=goal_selection_strategy, wrapped_env=self.wrapped_env) diff --git a/stable_baselines/her/replay_buffer.py b/stable_baselines/her/replay_buffer.py index f6e405e011..565cc28849 100644 --- a/stable_baselines/her/replay_buffer.py +++ b/stable_baselines/her/replay_buffer.py @@ -11,6 +11,7 @@ class GoalSelectionStrategy(Enum): RANDOM = 3 +# For convenience KEY_TO_GOAL_STRATEGY = { 'future': GoalSelectionStrategy.FUTURE, 'final': GoalSelectionStrategy.FINAL, @@ -20,23 +21,33 @@ class GoalSelectionStrategy(Enum): class HindsightExperienceReplayWrapper(object): - def __init__(self, replay_buffer, n_sampled_goal, goal_selection_strategy, wrapped_env): - """ - Inspired by https://github.com/NervanaSystems/coach/. + """ + Wrapper around a replay buffer in order to use HER. + This implementation is close to the one found in https://github.com/NervanaSystems/coach/. - :param n_sampled_goal: The number of artificial transitions to generate for each actual transition - :param goal_selection_strategy: The method that will be used for generating the goals for the - hindsight transitions. Should be one of GoalSelectionStrategy - :param wrapped_env: - """ + :param replay_buffer: (ReplayBuffer) + :param n_sampled_goal: (int) The number of artificial transitions to generate for each actual transition + :param goal_selection_strategy: (GoalSelectionStrategy) The method that will be used to generate + the goals for the artificial transitions. + :param wrapped_env: (HERGoalEnvWrapper) + """ + + def __init__(self, replay_buffer, n_sampled_goal, goal_selection_strategy, wrapped_env): super(HindsightExperienceReplayWrapper, self).__init__() self.n_sampled_goal = n_sampled_goal - assert isinstance(goal_selection_strategy, GoalSelectionStrategy) + + assert isinstance(goal_selection_strategy, GoalSelectionStrategy), "Invalid goal selection strategy," \ + "please use one of {}".format( + list(GoalSelectionStrategy)) + self.goal_selection_strategy = goal_selection_strategy self.env = wrapped_env self.current_episode = [] self.replay_buffer = replay_buffer + def append(self, obs_t, action, reward, obs_tp1, done): + return self.add(obs_t, action, reward, obs_tp1, done) + def add(self, obs_t, action, reward, obs_tp1, done): """ add a new transition to the buffer @@ -48,6 +59,7 @@ def add(self, obs_t, action, reward, obs_tp1, done): :param done: (bool) is the episode done """ assert self.replay_buffer is not None + # Update current episode buffer self.current_episode.append((obs_t, action, reward, obs_tp1, done)) if done: # Add transitions (and imagined ones) to buffer only when an episode is over @@ -62,10 +74,10 @@ def __len__(self): def _sample_goal(self, episode_transitions, transition_idx): """ - Sample an achieved goal according to the sampling method. + Sample an achieved goal according to the sampling strategy. - :param episode_transitions: a list of all the transitions in the current episode - :param transition_idx: the transition to start sampling from + :param episode_transitions: ([tuple]) a list of all the transitions in the current episode + :param transition_idx: (int) the transition to start sampling from :return: (np.ndarray) an achieved goal """ if self.goal_selection_strategy == GoalSelectionStrategy.FUTURE: @@ -92,9 +104,9 @@ def _sample_goals(self, episode_transitions, transition_idx): """ Sample a batch of achieved goal according to the sampling strategy. - :param episode_transitions: () a list of all the transitions in the current episode - :param transition_idx: the transition to start sampling from - :return: a goal corresponding to the sampled obs + :param episode_transitions: ([tuple]) a list of all the transitions in the current episode + :param transition_idx: (int) the transition to start sampling from + :return: (np.ndarray) a goal corresponding to the sampled obs """ return [ self._sample_goal(episode_transitions, transition_idx) @@ -102,9 +114,14 @@ def _sample_goals(self, episode_transitions, transition_idx): ] def _store_episode(self): + """ + Sample artificial goals and store transition of the current + episode in the replay buffer. + This method is called only after each end of episode. + """ last_episode_transitions = copy.deepcopy(self.current_episode) - # for each transition in the last episode, create a set of hindsight transitions + # For each transition in the last episode, create a set of hindsight transitions for transition_idx, transition in enumerate(last_episode_transitions): obs_t, action, reward, obs_tp1, done = transition diff --git a/stable_baselines/her/utils.py b/stable_baselines/her/utils.py index ee24e10cd9..dcb19639a0 100644 --- a/stable_baselines/her/utils.py +++ b/stable_baselines/her/utils.py @@ -14,16 +14,23 @@ def __init__(self, env): # TODO: check that all spaces are of the same type # (current limiation of the wrapper) # TODO: check when dim > 1 + + goal_space_shape = env.observation_space.spaces['achieved_goal'].shape self.obs_dim = env.observation_space.spaces['observation'].shape[0] - self.goal_dim = env.observation_space.spaces['achieved_goal'].shape[0] + self.goal_dim = goal_space_shape[0] total_dim = self.obs_dim + 2 * self.goal_dim + if len(goal_space_shape) == 2: + assert goal_space_shape[1] == 1 + else: + assert len(goal_space_shape) == 1 + if isinstance(self.spaces[0], spaces.MultiBinary): self.observation_space = spaces.MultiBinary(total_dim) elif isinstance(self.spaces[0], spaces.Box): - # total_dim = np.sum([space.shape[0] for space in self.spaces]) - # self.observation_space = spaces.Box(-np.inf, np.inf, shape=(total_dim, ), dtype=np.float32) - raise NotImplementedError() + lows = np.concatenate([space.low for space in self.spaces]) + highs = np.concatenate([space.high for space in self.spaces]) + self.observation_space = spaces.Box(lows, highs, dtype=np.float32) else: raise NotImplementedError() diff --git a/stable_baselines/sac/sac.py b/stable_baselines/sac/sac.py index aef03eaa66..aa9404d11a 100644 --- a/stable_baselines/sac/sac.py +++ b/stable_baselines/sac/sac.py @@ -352,10 +352,13 @@ def _train_step(self, step, writer, learning_rate): return policy_loss, qf1_loss, qf2_loss, value_loss, entropy def learn(self, total_timesteps, callback=None, seed=None, - log_interval=4, tb_log_name="SAC", reset_num_timesteps=True): + log_interval=4, tb_log_name="SAC", reset_num_timesteps=True, replay_wrapper=None): new_tb_log = self._init_num_timesteps(reset_num_timesteps) + if replay_wrapper is not None: + self.replay_buffer = replay_wrapper(self.replay_buffer) + with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \ as writer: From dab564715d99bd08404493bc3e0be35fea6b6119 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 20 Apr 2019 15:09:19 +0200 Subject: [PATCH 05/36] Add tests for SAC and DDPG + HER + add comments --- stable_baselines/her/her.py | 34 +++++++++++++++++++++++---- stable_baselines/her/replay_buffer.py | 24 +++++++++++++++---- stable_baselines/her/utils.py | 17 +++++++++++++- tests/test_her.py | 16 +++++++++++-- 4 files changed, 79 insertions(+), 12 deletions(-) diff --git a/stable_baselines/her/her.py b/stable_baselines/her/her.py index d45c76f68b..7d14b6a62b 100644 --- a/stable_baselines/her/her.py +++ b/stable_baselines/her/her.py @@ -23,6 +23,7 @@ def __init__(self, policy, env, model_class, n_sampled_goal=4, goal_selection_strategy='future', *args, **kwargs): # super().__init__(policy=policy, env=env, verbose=verbose, policy_base=None, requires_vec_env=False) + # TODO: check if the env is not already wrapped self.model_class = model_class self.env = env # TODO: check for TimeLimit wrapper too @@ -38,8 +39,32 @@ def __init__(self, policy, env, model_class, n_sampled_goal=4, wrapped_env=self.wrapped_env) self.model = self.model_class(policy, self.wrapped_env, *args, **kwargs) + def set_env(self, env): + self.env = env + self.wrapped_env = HERGoalEnvWrapper(env) + self.model.set_env(self.wrapped_env) + + def get_env(self): + return self.wrapped_env + + def __getattr__(self, attr): + """ + Wrap the RL model. + :param attr: (str) + :return: (Any) + """ + if attr in self.__dict__: + return getattr(self, attr) + return getattr(self.model, attr) + + def __set_attr__(self, attr, value): + if attr in self.__dict__: + setattr(self, attr, value) + else: + set_attr(self.model, attr, value) + def _get_pretrain_placeholders(self): - raise NotImplementedError() + return self.model._get_pretrain_placeholders() def setup_model(self): pass @@ -49,11 +74,12 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ return self.model.learn(total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="HER", reset_num_timesteps=True, replay_wrapper=self.replay_wrapper) - def predict(self, observation, state=None, mask=None, deterministic=False): - pass + def predict(self, observation, state=None, mask=None, deterministic=True): + # TODO: assert the type of observation + return self.model.predict(observation, state, mask, deterministic) def action_probability(self, observation, state=None, mask=None, actions=None): - pass + return self.model.action_probability(observation, state, mask, actions) def save(self, save_path): pass diff --git a/stable_baselines/her/replay_buffer.py b/stable_baselines/her/replay_buffer.py index 565cc28849..ee63b7ab7c 100644 --- a/stable_baselines/her/replay_buffer.py +++ b/stable_baselines/her/replay_buffer.py @@ -5,13 +5,26 @@ class GoalSelectionStrategy(Enum): + """ + The strategies for selecting new goals when + creating artificial transitions. + """ + # Select a goal that was achieved + # after the current step, in the same episode FUTURE = 0 + # Select the goal that was achieved + # at the end of the episode FINAL = 1 + # Select a goal that was achieved in the episode EPISODE = 2 + # Select a goal that was achieved + # at some point in the training procedure + # (and that is present in the replay buffer) RANDOM = 3 # For convenience +# that way, we can use string to select a strategy KEY_TO_GOAL_STRATEGY = { 'future': GoalSelectionStrategy.FUTURE, 'final': GoalSelectionStrategy.FINAL, @@ -29,7 +42,8 @@ class HindsightExperienceReplayWrapper(object): :param n_sampled_goal: (int) The number of artificial transitions to generate for each actual transition :param goal_selection_strategy: (GoalSelectionStrategy) The method that will be used to generate the goals for the artificial transitions. - :param wrapped_env: (HERGoalEnvWrapper) + :param wrapped_env: (HERGoalEnvWrapper) the GoalEnv wrapped using HERGoalEnvWrapper, + that enables to convert observation to dict, and vice versa """ def __init__(self, replay_buffer, n_sampled_goal, goal_selection_strategy, wrapped_env): @@ -72,7 +86,7 @@ def sample(self, *args, **kwargs): def __len__(self): return len(self.replay_buffer) - def _sample_goal(self, episode_transitions, transition_idx): + def _sample_achieved_goal(self, episode_transitions, transition_idx): """ Sample an achieved goal according to the sampling strategy. @@ -100,7 +114,7 @@ def _sample_goal(self, episode_transitions, transition_idx): "please use one of {}".format(list(GoalSelectionStrategy))) return self.env.convert_obs_to_dict(selected_transition[0])['achieved_goal'] - def _sample_goals(self, episode_transitions, transition_idx): + def _sample_achieved_goals(self, episode_transitions, transition_idx): """ Sample a batch of achieved goal according to the sampling strategy. @@ -109,7 +123,7 @@ def _sample_goals(self, episode_transitions, transition_idx): :return: (np.ndarray) a goal corresponding to the sampled obs """ return [ - self._sample_goal(episode_transitions, transition_idx) + self._sample_achieved_goal(episode_transitions, transition_idx) for _ in range(self.n_sampled_goal) ] @@ -134,7 +148,7 @@ def _store_episode(self): # Sampled n goals per transition, where n is `n_sampled_goal` # this is called k in the paper - sampled_goals = self._sample_goals(last_episode_transitions, transition_idx) + sampled_goals = self._sample_achieved_goals(last_episode_transitions, transition_idx) # For each sampled goals, store a new transition for goal in sampled_goals: obs, action, reward, next_obs, done = copy.deepcopy(transition) diff --git a/stable_baselines/her/utils.py b/stable_baselines/her/utils.py index dcb19639a0..38b1712123 100644 --- a/stable_baselines/her/utils.py +++ b/stable_baselines/her/utils.py @@ -3,7 +3,12 @@ class HERGoalEnvWrapper(object): - """docstring for HERGoalEnvWrapper.""" + """ + A wrapper that allow to use dict env (coming from GoalEnv) with + the RL algorithms. + + :param env: (gym.GoalEnv) + """ def __init__(self, env): super(HERGoalEnvWrapper, self).__init__() @@ -31,15 +36,25 @@ def __init__(self, env): lows = np.concatenate([space.low for space in self.spaces]) highs = np.concatenate([space.high for space in self.spaces]) self.observation_space = spaces.Box(lows, highs, dtype=np.float32) + elif isinstance(self.spaces[0], spaces.Discrete): + pass else: raise NotImplementedError() @staticmethod def convert_dict_to_obs(obs_dict): + """ + :param obs_dict: (dict) + :return: (np.ndarray) + """ # Note: we should remove achieved goal from the observation ? return np.concatenate([obs for obs in obs_dict.values()]) def convert_obs_to_dict(self, observations): + """ + :param observations: (np.ndarray) + :return: (dict) + """ return { 'observation': observations[:self.obs_dim], 'achieved_goal': observations[self.obs_dim:self.obs_dim + self.goal_dim], diff --git a/tests/test_her.py b/tests/test_her.py index cef487c1e4..8be095e7ec 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -1,5 +1,7 @@ import pytest +import numpy as np + from stable_baselines import HER, DQN, SAC, DDPG from stable_baselines.her import GoalSelectionStrategy from stable_baselines.common.bit_flipping_env import BitFlippingEnv @@ -11,5 +13,15 @@ def test_dqn_her(goal_selection_strategy): env = BitFlippingEnv(N_BITS, continuous=False, max_steps=N_BITS) model = HER('MlpPolicy', env, DQN, n_sampled_goal=4, goal_selection_strategy=goal_selection_strategy, - prioritized_replay=False, verbose=1) - model.learn(5000) + prioritized_replay=False, verbose=0) + model.learn(1000) + + +@pytest.mark.parametrize('goal_selection_strategy', list(GoalSelectionStrategy)) +@pytest.mark.parametrize('model_class', [SAC, DDPG]) +def test_continuous_her(model_class, goal_selection_strategy): + env = BitFlippingEnv(N_BITS, continuous=True, max_steps=N_BITS) + + model = HER('MlpPolicy', env, model_class, n_sampled_goal=4, goal_selection_strategy=goal_selection_strategy, + verbose=0) + model.learn(1000) From 9e42f1e9458224247f505cf09f2e241393081a83 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 20 Apr 2019 16:39:25 +0200 Subject: [PATCH 06/36] Bug fix + add comments --- stable_baselines/ddpg/main.py | 4 +-- stable_baselines/her/her.py | 24 ++++++++++++++- stable_baselines/her/replay_buffer.py | 42 +++++++++++++++------------ tests/test_her.py | 17 ++--------- 4 files changed, 52 insertions(+), 35 deletions(-) diff --git a/stable_baselines/ddpg/main.py b/stable_baselines/ddpg/main.py index 466fd7a7f2..3e1232788c 100644 --- a/stable_baselines/ddpg/main.py +++ b/stable_baselines/ddpg/main.py @@ -86,8 +86,8 @@ def run(env_id, seed, noise_type, layer_norm, evaluation, **kwargs): num_timesteps = kwargs['num_timesteps'] del kwargs['num_timesteps'] - model = DDPG(policy=policy, env=env, memory_policy=Memory, eval_env=eval_env, param_noise=param_noise, - action_noise=action_noise, memory_limit=int(1e6), verbose=2, **kwargs) + model = DDPG(policy=policy, env=env, eval_env=eval_env, param_noise=param_noise, + action_noise=action_noise, buffer_size=int(1e6), verbose=2, **kwargs) model.learn(total_timesteps=num_timesteps) env.close() if eval_env is not None: diff --git a/stable_baselines/her/her.py b/stable_baselines/her/her.py index 7d14b6a62b..173d46571a 100644 --- a/stable_baselines/her/her.py +++ b/stable_baselines/her/her.py @@ -27,6 +27,7 @@ def __init__(self, policy, env, model_class, n_sampled_goal=4, self.model_class = model_class self.env = env # TODO: check for TimeLimit wrapper too + # TODO: support VecEnv # assert isinstance(self.env, gym.GoalEnv), "HER only supports gym.GoalEnv" self.wrapped_env = HERGoalEnvWrapper(env) @@ -81,9 +82,30 @@ def predict(self, observation, state=None, mask=None, deterministic=True): def action_probability(self, observation, state=None, mask=None, actions=None): return self.model.action_probability(observation, state, mask, actions) + # def _save_to_file(self, save_path, data=None, params=None): + # # HACK to save the replay wrapper + # # or better to save only the replay strategy and its params? + # # it will not work with VecEnv + # data['replay_wrapper'] = self.replay_wrapper + # data['model_class'] = self.model_class + # super()._save_to_file(save_path, data, params) + def save(self, save_path): - pass + # Is there something more to save? (the replay wrapper?) + self.model.save(save_path) @classmethod def load(cls, load_path, env=None, **kwargs): pass + # data, params = cls._load_from_file(load_path) + # + # if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']: + # raise ValueError("The specified policy kwargs do not equal the stored policy kwargs. " + # "Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'], + # kwargs['policy_kwargs'])) + # + # model = cls(policy=data["policy"], env=env, model_class=data['model_class'], _init_setup_model=False) + # model.__dict__.update(data) + # model.__dict__.update(kwargs) + # model.model = data['model_class'].load(load_path, model.get_env()) + # return model diff --git a/stable_baselines/her/replay_buffer.py b/stable_baselines/her/replay_buffer.py index ee63b7ab7c..d603b9eb2a 100644 --- a/stable_baselines/her/replay_buffer.py +++ b/stable_baselines/her/replay_buffer.py @@ -48,15 +48,16 @@ class HindsightExperienceReplayWrapper(object): def __init__(self, replay_buffer, n_sampled_goal, goal_selection_strategy, wrapped_env): super(HindsightExperienceReplayWrapper, self).__init__() - self.n_sampled_goal = n_sampled_goal assert isinstance(goal_selection_strategy, GoalSelectionStrategy), "Invalid goal selection strategy," \ "please use one of {}".format( list(GoalSelectionStrategy)) + self.n_sampled_goal = n_sampled_goal self.goal_selection_strategy = goal_selection_strategy self.env = wrapped_env - self.current_episode = [] + # Buffer for storing transitions of the current episode + self.episode_transitions = [] self.replay_buffer = replay_buffer def append(self, obs_t, action, reward, obs_tp1, done): @@ -74,11 +75,12 @@ def add(self, obs_t, action, reward, obs_tp1, done): """ assert self.replay_buffer is not None # Update current episode buffer - self.current_episode.append((obs_t, action, reward, obs_tp1, done)) + self.episode_transitions.append((obs_t, action, reward, obs_tp1, done)) if done: # Add transitions (and imagined ones) to buffer only when an episode is over self._store_episode() - self.current_episode = [] + # Reset episode buffer + self.episode_transitions = [] def sample(self, *args, **kwargs): return self.replay_buffer.sample(*args, **kwargs) @@ -99,7 +101,7 @@ def _sample_achieved_goal(self, episode_transitions, transition_idx): selected_idx = np.random.choice(np.arange(transition_idx + 1, len(episode_transitions))) selected_transition = episode_transitions[selected_idx] elif self.goal_selection_strategy == GoalSelectionStrategy.FINAL: - # The achieved goal at the end of the episode + # Choose the goal achieved at the end of the episode selected_transition = episode_transitions[-1] elif self.goal_selection_strategy == GoalSelectionStrategy.EPISODE: # Random achieved goal during the episode @@ -116,11 +118,11 @@ def _sample_achieved_goal(self, episode_transitions, transition_idx): def _sample_achieved_goals(self, episode_transitions, transition_idx): """ - Sample a batch of achieved goal according to the sampling strategy. + Sample a batch of achieved goals according to the sampling strategy. - :param episode_transitions: ([tuple]) a list of all the transitions in the current episode + :param episode_transitions: ([tuple]) list of the transitions in the current episode :param transition_idx: (int) the transition to start sampling from - :return: (np.ndarray) a goal corresponding to the sampled obs + :return: (np.ndarray) an achieved goal """ return [ self._sample_achieved_goal(episode_transitions, transition_idx) @@ -133,40 +135,44 @@ def _store_episode(self): episode in the replay buffer. This method is called only after each end of episode. """ - last_episode_transitions = copy.deepcopy(self.current_episode) + # NOTE: is deepcopy really needed here? + # last_episode_transitions = copy.deepcopy(self.episode_transitions) - # For each transition in the last episode, create a set of hindsight transitions - for transition_idx, transition in enumerate(last_episode_transitions): + # For each transition in the last episode, + # create a set of artificial transitions + for transition_idx, transition in enumerate(self.episode_transitions): obs_t, action, reward, obs_tp1, done = transition # Add to the replay buffer self.replay_buffer.add(obs_t, action, reward, obs_tp1, done) + # We cannot sample a goal from the future in the last step of an episode - if (transition_idx == len(last_episode_transitions) - 1 and + if (transition_idx == len(self.episode_transitions) - 1 and self.goal_selection_strategy == GoalSelectionStrategy.FUTURE): break # Sampled n goals per transition, where n is `n_sampled_goal` # this is called k in the paper - sampled_goals = self._sample_achieved_goals(last_episode_transitions, transition_idx) + sampled_goals = self._sample_achieved_goals(self.episode_transitions, transition_idx) # For each sampled goals, store a new transition for goal in sampled_goals: + # Copy transition to avoid modifying the original one obs, action, reward, next_obs, done = copy.deepcopy(transition) - # Convert concatenated obs to dict - # so we can update the goals + # Convert concatenated obs to dict, so we can update the goals obs_dict, next_obs_dict = map(self.env.convert_obs_to_dict, (obs, next_obs)) - # update the desired goal in the transition + # Update the desired goal in the transition obs_dict['desired_goal'] = goal next_obs_dict['desired_goal'] = goal - # update the reward and terminal signal according to the goal + # Update the reward according to the new desired goal reward = self.env.compute_reward(goal, next_obs_dict['achieved_goal'], None) # Can we ensure that done = reward == 0 done = False + # Transform back to ndarrays obs, next_obs = map(self.env.convert_dict_to_obs, (obs_dict, next_obs_dict)) - # Add to the replay buffer + # Add artificial transition to the replay buffer self.replay_buffer.add(obs, action, reward, next_obs, done) diff --git a/tests/test_her.py b/tests/test_her.py index 8be095e7ec..91363134d8 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -1,26 +1,15 @@ import pytest -import numpy as np - from stable_baselines import HER, DQN, SAC, DDPG from stable_baselines.her import GoalSelectionStrategy from stable_baselines.common.bit_flipping_env import BitFlippingEnv N_BITS = 10 - -@pytest.mark.parametrize('goal_selection_strategy', list(GoalSelectionStrategy)) -def test_dqn_her(goal_selection_strategy): - env = BitFlippingEnv(N_BITS, continuous=False, max_steps=N_BITS) - model = HER('MlpPolicy', env, DQN, n_sampled_goal=4, goal_selection_strategy=goal_selection_strategy, - prioritized_replay=False, verbose=0) - model.learn(1000) - - @pytest.mark.parametrize('goal_selection_strategy', list(GoalSelectionStrategy)) -@pytest.mark.parametrize('model_class', [SAC, DDPG]) -def test_continuous_her(model_class, goal_selection_strategy): - env = BitFlippingEnv(N_BITS, continuous=True, max_steps=N_BITS) +@pytest.mark.parametrize('model_class', [DQN, SAC, DDPG]) +def test_her(model_class, goal_selection_strategy): + env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC], max_steps=N_BITS) model = HER('MlpPolicy', env, model_class, n_sampled_goal=4, goal_selection_strategy=goal_selection_strategy, verbose=0) From 63ffc83834850915be6f584415fa4ad451347ec5 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 20 Apr 2019 23:37:28 +0200 Subject: [PATCH 07/36] Add action noise for SAC --- docs/misc/changelog.rst | 1 + stable_baselines/sac/sac.py | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 7f87d58e53..bb1efcb0b6 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -12,6 +12,7 @@ Pre-Release 2.5.1a0 (WIP) - fixed logger issues when stdout lacks ``read`` function - **deprecated** ``memory_limit`` and ``memory_policy`` in DDPG, please use ``buffer_size`` instead. (will be removed in v3.x.x) - removed ``stable_baselines.ddpg.memory`` in favor of ``stable_baselines.deepq.replay_buffer`` +- add ``action_noise`` param for SAC, it helps exploration for problem with deceptive reward Release 2.5.0 (2019-03-28) diff --git a/stable_baselines/sac/sac.py b/stable_baselines/sac/sac.py index aa9404d11a..41c87c0b5d 100644 --- a/stable_baselines/sac/sac.py +++ b/stable_baselines/sac/sac.py @@ -53,6 +53,8 @@ class SAC(OffPolicyRLModel): :param target_update_interval: (int) update the target network every `target_network_update_freq` steps. :param gradient_steps: (int) How many gradient update after each step :param target_entropy: (str or float) target entropy when learning ent_coef (ent_coef = 'auto') + :param action_noise: (ActionNoise) the action noise type (None by default), this can help + for hard exploration problem. Cf DDPG for the different action noise type. :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance @@ -64,7 +66,8 @@ class SAC(OffPolicyRLModel): def __init__(self, policy, env, gamma=0.99, learning_rate=3e-4, buffer_size=50000, learning_starts=100, train_freq=1, batch_size=64, tau=0.005, ent_coef='auto', target_update_interval=1, - gradient_steps=1, target_entropy='auto', verbose=0, tensorboard_log=None, + gradient_steps=1, target_entropy='auto', action_noise=None, + verbose=0, tensorboard_log=None, _init_setup_model=True, policy_kwargs=None, full_tensorboard_log=False): super(SAC, self).__init__(policy=policy, env=env, replay_buffer=None, verbose=verbose, @@ -86,6 +89,7 @@ def __init__(self, policy, env, gamma=0.99, learning_rate=3e-4, buffer_size=5000 self.target_update_interval = target_update_interval self.gradient_steps = gradient_steps self.gamma = gamma + self.action_noise = action_noise self.value_fn = None self.graph = None @@ -371,6 +375,8 @@ def learn(self, total_timesteps, callback=None, seed=None, start_time = time.time() episode_rewards = [0.0] + if self.action_noise is not None: + self.action_noise.reset() obs = self.env.reset() self.episode_reward = np.zeros((1,)) ep_info_buf = deque(maxlen=100) @@ -393,6 +399,10 @@ def learn(self, total_timesteps, callback=None, seed=None, rescaled_action = action else: action = self.policy_tf.step(obs[None], deterministic=False).flatten() + # Add noise to the action (improve exploration, + # not needed in general) + if self.action_noise is not None: + action = np.clip(action + self.action_noise(), -1, 1) # Rescale from [-1, 1] to the correct bounds rescaled_action = action * np.abs(self.action_space.low) @@ -438,6 +448,8 @@ def learn(self, total_timesteps, callback=None, seed=None, episode_rewards[-1] += reward if done: + if self.action_noise is not None: + self.action_noise.reset() if not isinstance(self.env, VecEnv): obs = self.env.reset() episode_rewards.append(0.0) @@ -518,6 +530,7 @@ def save(self, save_path): "action_space": self.action_space, "policy": self.policy, "n_envs": self.n_envs, + "action_noise": self.action_noise, "_vectorize_action": self._vectorize_action, "policy_kwargs": self.policy_kwargs } From 2e792619b6dfb3abc4ec5321615251bcaec4919c Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 21 Apr 2019 13:03:28 +0200 Subject: [PATCH 08/36] Add note about pop-art normalization --- stable_baselines/ddpg/ddpg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines/ddpg/ddpg.py b/stable_baselines/ddpg/ddpg.py index afc3352f6c..3c33d82f6a 100644 --- a/stable_baselines/ddpg/ddpg.py +++ b/stable_baselines/ddpg/ddpg.py @@ -154,7 +154,7 @@ class DDPG(OffPolicyRLModel): :param tau: (float) the soft update coefficient (keep old values, between 0 and 1) :param normalize_returns: (bool) should the critic output be normalized :param enable_popart: (bool) enable pop-art normalization of the critic output - (https://arxiv.org/pdf/1602.07714.pdf) + (https://arxiv.org/pdf/1602.07714.pdf), normalize_returns must be set to True. :param normalize_observations: (bool) should the observation be normalized :param batch_size: (int) the size of the batch for learning the policy :param observation_range: (tuple) the bounding values for the observation From a9f43afa1c15c17e8fbbaf90baf6d47992612df2 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 22 Apr 2019 13:11:52 +0200 Subject: [PATCH 09/36] Add saving/loading + begin support for VecEnv --- stable_baselines/common/base_class.py | 33 ++++- .../common/vec_env/base_vec_env.py | 3 + .../common/vec_env/dummy_vec_env.py | 1 + stable_baselines/her/her.py | 116 ++++++++++++------ tests/test_her.py | 90 ++++++++++++++ 5 files changed, 203 insertions(+), 40 deletions(-) diff --git a/stable_baselines/common/base_class.py b/stable_baselines/common/base_class.py index b2ae6952a9..755b20f173 100644 --- a/stable_baselines/common/base_class.py +++ b/stable_baselines/common/base_class.py @@ -2,6 +2,7 @@ import os import glob import warnings +from collections import OrderedDict import cloudpickle import numpy as np @@ -27,7 +28,7 @@ class BaseRLModel(ABC): """ def __init__(self, policy, env, verbose=0, *, requires_vec_env, policy_base, policy_kwargs=None): - if isinstance(policy, str): + if isinstance(policy, str) and policy_base is not None: self.policy = get_policy_from_name(policy_base, policy) else: self.policy = policy @@ -624,15 +625,39 @@ def __init__(self, venv): super().__init__(venv) assert venv.num_envs == 1, "Error: cannot unwrap a environment wrapper that has more than one environment." + def __getattr__(self, attr): + if attr in self.__dict__: + return getattr(self, attr) + return getattr(self.venv, attr) + + def __set_attr__(self, attr, value): + if attr in self.__dict__: + setattr(self, attr, value) + else: + set_attr(self.venv, attr, value) + + def compute_reward(self, achieved_goal, desired_goal, _info): + return float(self.venv.env_method('compute_reward', achieved_goal, desired_goal, _info)[0]) + + @staticmethod + def unvec_obs(obs): + if not isinstance(obs, dict): + return obs[0] + obs_ = OrderedDict() + for key in obs.keys(): + obs_[key] = obs[key][0] + del obs + return obs_ + def reset(self): - return self.venv.reset()[0] + return self.unvec_obs(self.venv.reset()) def step_async(self, actions): self.venv.step_async([actions]) def step_wait(self): - actions, values, states, information = self.venv.step_wait() - return actions[0], float(values[0]), states[0], information[0] + obs, rewards, dones, information = self.venv.step_wait() + return self.unvec_obs(obs), float(rewards[0]), dones[0], information[0] def render(self, mode='human'): return self.venv.render(mode=mode) diff --git a/stable_baselines/common/vec_env/base_vec_env.py b/stable_baselines/common/vec_env/base_vec_env.py index 1e51993b39..442fe29b7c 100644 --- a/stable_baselines/common/vec_env/base_vec_env.py +++ b/stable_baselines/common/vec_env/base_vec_env.py @@ -35,6 +35,9 @@ class VecEnv(ABC): :param observation_space: (Gym Space) the observation space :param action_space: (Gym Space) the action space """ + metadata = { + 'render.modes': ['human', 'rgb_array'] + } def __init__(self, num_envs, observation_space, action_space): self.num_envs = num_envs diff --git a/stable_baselines/common/vec_env/dummy_vec_env.py b/stable_baselines/common/vec_env/dummy_vec_env.py index 9c7fc8ae83..52b3d8c940 100644 --- a/stable_baselines/common/vec_env/dummy_vec_env.py +++ b/stable_baselines/common/vec_env/dummy_vec_env.py @@ -26,6 +26,7 @@ def __init__(self, env_fns): self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32) self.buf_infos = [{} for _ in range(self.num_envs)] self.actions = None + self.metadata = env.metadata def step_async(self, actions): self.actions = actions diff --git a/stable_baselines/her/her.py b/stable_baselines/her/her.py index 173d46571a..094b22ff2a 100644 --- a/stable_baselines/her/her.py +++ b/stable_baselines/her/her.py @@ -3,6 +3,8 @@ import gym from stable_baselines.common import BaseRLModel +from stable_baselines.common import OffPolicyRLModel +from stable_baselines.common.base_class import _UnvecWrapper from .replay_buffer import HindsightExperienceReplayWrapper, KEY_TO_GOAL_STRATEGY from .utils import HERGoalEnvWrapper @@ -21,36 +23,64 @@ class HER(BaseRLModel): def __init__(self, policy, env, model_class, n_sampled_goal=4, goal_selection_strategy='future', *args, **kwargs): - # super().__init__(policy=policy, env=env, verbose=verbose, policy_base=None, requires_vec_env=False) - # TODO: check if the env is not already wrapped + super().__init__(policy=policy, env=env, verbose=kwargs.get('verbose', 0), + policy_base=None, requires_vec_env=False) + self.model_class = model_class + self.replay_wrapper = None + + # Convert string to GoalSelectionStrategy object + if isinstance(goal_selection_strategy, str): + assert goal_selection_strategy in KEY_TO_GOAL_STRATEGY.keys(), "Unknown goal selection strategy" + goal_selection_strategy = KEY_TO_GOAL_STRATEGY[goal_selection_strategy] + + self.n_sampled_goal = n_sampled_goal + self.goal_selection_strategy = goal_selection_strategy + + if self.env is not None: + self._create_replay_wrapper(self.env) + + assert issubclass(model_class, OffPolicyRLModel),\ + "Error: HER only works with Off policy model (such as DDPG, SAC and DQN)." + + self.model = self.model_class(policy, self.env, *args, **kwargs) + self.model._save_to_file = self._save_to_file + + + def _create_replay_wrapper(self, env): + # if isinstance(env, VecEnv): + # assert isinstance(env, _UnvecWrapper) + + # TODO: check if the env is not already wrapped + if not isinstance(env, HERGoalEnvWrapper): + env = HERGoalEnvWrapper(env) + self.env = env # TODO: check for TimeLimit wrapper too # TODO: support VecEnv # assert isinstance(self.env, gym.GoalEnv), "HER only supports gym.GoalEnv" - self.wrapped_env = HERGoalEnvWrapper(env) - - if isinstance(goal_selection_strategy, str): - assert goal_selection_strategy in KEY_TO_GOAL_STRATEGY.keys() - goal_selection_strategy = KEY_TO_GOAL_STRATEGY[goal_selection_strategy] - self.replay_wrapper = functools.partial(HindsightExperienceReplayWrapper, n_sampled_goal=n_sampled_goal, - goal_selection_strategy=goal_selection_strategy, - wrapped_env=self.wrapped_env) - self.model = self.model_class(policy, self.wrapped_env, *args, **kwargs) + self.replay_wrapper = functools.partial(HindsightExperienceReplayWrapper, + n_sampled_goal=self.n_sampled_goal, + goal_selection_strategy=self.goal_selection_strategy, + wrapped_env=self.env) def set_env(self, env): - self.env = env - self.wrapped_env = HERGoalEnvWrapper(env) - self.model.set_env(self.wrapped_env) + # Unwrap VecEnv if needed + # TODO: save/load correct observation_space + # which is different between HER and the wrapped env + # super().set_env(env) + self._create_replay_wrapper(env) + self.model.set_env(self.env) def get_env(self): - return self.wrapped_env + return self.env def __getattr__(self, attr): """ Wrap the RL model. + :param attr: (str) :return: (Any) """ @@ -75,20 +105,31 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ return self.model.learn(total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="HER", reset_num_timesteps=True, replay_wrapper=self.replay_wrapper) + def _check_obs(self, observation): + if isinstance(observation, dict): + if self.env is not None: + if len(observation['observation'].shape) > 1: + observation = _UnvecWrapper.unvec_obs(observation) + return [self.env.convert_dict_to_obs(observation)] + return self.env.convert_dict_to_obs(observation) + else: + raise ValueError("You must either pass an env to HER or wrap your env using HERGoalEnvWrapper") + return observation + def predict(self, observation, state=None, mask=None, deterministic=True): - # TODO: assert the type of observation - return self.model.predict(observation, state, mask, deterministic) + return self.model.predict(self._check_obs(observation), state, mask, deterministic) def action_probability(self, observation, state=None, mask=None, actions=None): - return self.model.action_probability(observation, state, mask, actions) + return self.model.action_probability(self._check_obs(observation), state, mask, actions) - # def _save_to_file(self, save_path, data=None, params=None): - # # HACK to save the replay wrapper - # # or better to save only the replay strategy and its params? - # # it will not work with VecEnv - # data['replay_wrapper'] = self.replay_wrapper - # data['model_class'] = self.model_class - # super()._save_to_file(save_path, data, params) + def _save_to_file(self, save_path, data=None, params=None): + # HACK to save the replay wrapper + # or better to save only the replay strategy and its params? + # it will not work with VecEnv + data['n_sampled_goal'] = self.n_sampled_goal + data['goal_selection_strategy'] = self.goal_selection_strategy + data['model_class'] = self.model_class + super()._save_to_file(save_path, data, params) def save(self, save_path): # Is there something more to save? (the replay wrapper?) @@ -96,16 +137,19 @@ def save(self, save_path): @classmethod def load(cls, load_path, env=None, **kwargs): - pass - # data, params = cls._load_from_file(load_path) - # - # if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']: - # raise ValueError("The specified policy kwargs do not equal the stored policy kwargs. " - # "Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'], - # kwargs['policy_kwargs'])) - # - # model = cls(policy=data["policy"], env=env, model_class=data['model_class'], _init_setup_model=False) + data, _ = cls._load_from_file(load_path) + + if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']: + raise ValueError("The specified policy kwargs do not equal the stored policy kwargs. " + "Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'], + kwargs['policy_kwargs'])) + + model = cls(policy=data["policy"], env=env, model_class=data['model_class'], + n_sampled_goal=data['n_sampled_goal'], + goal_selection_strategy=data['goal_selection_strategy'], + _init_setup_model=False) # model.__dict__.update(data) # model.__dict__.update(kwargs) - # model.model = data['model_class'].load(load_path, model.get_env()) - # return model + model.model = data['model_class'].load(load_path, model.get_env(), **kwargs) + model.model._save_to_file = model._save_to_file + return model diff --git a/tests/test_her.py b/tests/test_her.py index 91363134d8..139e11430f 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -1,11 +1,16 @@ +import os + import pytest from stable_baselines import HER, DQN, SAC, DDPG from stable_baselines.her import GoalSelectionStrategy +from stable_baselines.her.replay_buffer import KEY_TO_GOAL_STRATEGY from stable_baselines.common.bit_flipping_env import BitFlippingEnv +from stable_baselines.common.vec_env import DummyVecEnv N_BITS = 10 + @pytest.mark.parametrize('goal_selection_strategy', list(GoalSelectionStrategy)) @pytest.mark.parametrize('model_class', [DQN, SAC, DDPG]) def test_her(model_class, goal_selection_strategy): @@ -14,3 +19,88 @@ def test_her(model_class, goal_selection_strategy): model = HER('MlpPolicy', env, model_class, n_sampled_goal=4, goal_selection_strategy=goal_selection_strategy, verbose=0) model.learn(1000) + + +@pytest.mark.parametrize('goal_selection_strategy', [list(KEY_TO_GOAL_STRATEGY.keys())[0]]) +@pytest.mark.parametrize('model_class', [DQN, SAC, DDPG]) +def test_model_manipulation(model_class, goal_selection_strategy): + env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC], max_steps=N_BITS) + + model = HER('MlpPolicy', env, model_class, n_sampled_goal=3, goal_selection_strategy=goal_selection_strategy, + verbose=0) + model.learn(1000) + + model.save('./test_her') + + obs = env.reset() + for _ in range(100): + action, _ = model.predict(obs) + obs, _, done, _ = env.step(action) + if done: + obs = env.reset() + + del model + + model = HER.load('./test_her') + model.set_env(env) + model.learn(1000) + + env = model.get_env() + obs = env.reset() + for _ in range(100): + action, _ = model.predict(obs) + obs, _, done, _ = env.step(action) + if done: + obs = env.reset() + + assert model.n_sampled_goal == 3 + + del model + + model = HER.load('./test_her', env=env) + model.learn(1000) + + assert model.n_sampled_goal == 3 + + if os.path.isfile('./test_her.pkl'): + os.remove('./test_her.pkl') + + +@pytest.mark.parametrize('goal_selection_strategy', [list(KEY_TO_GOAL_STRATEGY.keys())[0]]) +@pytest.mark.parametrize('model_class', [DQN, SAC, DDPG]) +def test_her_vec_env(model_class, goal_selection_strategy): + env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC], max_steps=N_BITS) + # Test vec env + env = DummyVecEnv([lambda: env]) + + model = HER('MlpPolicy', env, model_class, n_sampled_goal=3, goal_selection_strategy=goal_selection_strategy, + verbose=0) + model.learn(1000) + + obs = env.reset() + for _ in range(100): + action, _ = model.predict(obs) + obs, _, done, _ = env.step(action) + if done: + obs = env.reset() + +# model.save('./test_her') +# del model +# +# # TODO: wrap into an _UnvecWrapper if needed +# # model = HER.load('./test_her') +# # model.set_env(env) +# # model.learn(1000) +# # +# # assert model.n_sampled_goal == 3 +# # +# # del model +# +# env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC], max_steps=N_BITS) +# model = HER.load('./test_her', env=env) +# model.learn(1000) +# +# assert model.n_sampled_goal == 3 +# +# if os.path.isfile('./test_her.pkl'): +# os.remove('./test_her.pkl') From ca32a5faf7c45cba534f077284a86ba24b6ca752 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 22 Apr 2019 23:24:46 +0200 Subject: [PATCH 10/36] Add success rate --- stable_baselines/common/bit_flipping_env.py | 3 ++- stable_baselines/ddpg/ddpg.py | 9 ++++++++- stable_baselines/deepq/dqn.py | 8 +++++++- stable_baselines/sac/sac.py | 7 +++++++ 4 files changed, 24 insertions(+), 3 deletions(-) diff --git a/stable_baselines/common/bit_flipping_env.py b/stable_baselines/common/bit_flipping_env.py index 571ed9d861..768b7d0820 100644 --- a/stable_baselines/common/bit_flipping_env.py +++ b/stable_baselines/common/bit_flipping_env.py @@ -58,8 +58,9 @@ def step(self, action): done = (obs['achieved_goal'] == obs['desired_goal']).all() self.current_step += 1 # Episode terminate when we reached the goal or the max number of steps + info = {'is_success': done} done = done or self.current_step >= self.max_steps - return obs, reward, done, {} + return obs, reward, done, info def compute_reward(self, achieved_goal, desired_goal, _info): # Deceptive reward: it is positive only when the goal is achieved diff --git a/stable_baselines/ddpg/ddpg.py b/stable_baselines/ddpg/ddpg.py index 2d54e90952..e2377fa937 100644 --- a/stable_baselines/ddpg/ddpg.py +++ b/stable_baselines/ddpg/ddpg.py @@ -808,6 +808,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ eval_episode_rewards_history = deque(maxlen=100) episode_rewards_history = deque(maxlen=100) self.episode_reward = np.zeros((1,)) + episode_successes = [] with self.sess.as_default(), self.graph.as_default(): # Prepare everything. self._reset() @@ -848,7 +849,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ # Execute next action. if rank == 0 and self.render: self.env.render() - new_obs, reward, done, _ = self.env.step(action * np.abs(self.action_space.low)) + new_obs, reward, done, info = self.env.step(action * np.abs(self.action_space.low)) if writer is not None: ep_rew = np.array([reward]).reshape((1, -1)) @@ -884,6 +885,10 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ epoch_episodes += 1 episodes += 1 + maybe_is_success = info.get('is_success') + if maybe_is_success is not None: + episode_successes.append(float(maybe_is_success)) + self._reset() if not isinstance(self.env, VecEnv): obs = self.env.reset() @@ -985,6 +990,8 @@ def as_scalar(scalar): for key in sorted(combined_stats.keys()): logger.record_tabular(key, combined_stats[key]) + if len(episode_successes) > 0: + logger.logkv("success rate", np.mean(episode_successes[-100:])) logger.dump_tabular() logger.info('') logdir = logger.get_dir() diff --git a/stable_baselines/deepq/dqn.py b/stable_baselines/deepq/dqn.py index 86d7ba4cdc..48165f4206 100644 --- a/stable_baselines/deepq/dqn.py +++ b/stable_baselines/deepq/dqn.py @@ -176,6 +176,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ final_p=self.exploration_final_eps) episode_rewards = [0.0] + episode_successes = [] obs = self.env.reset() reset = True self.episode_reward = np.zeros((1,)) @@ -207,7 +208,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ action = self.act(np.array(obs)[None], update_eps=update_eps, **kwargs)[0] env_action = action reset = False - new_obs, rew, done, _ = self.env.step(env_action) + new_obs, rew, done, info = self.env.step(env_action) # Store transition in the replay buffer. self.replay_buffer.add(obs, action, rew, new_obs, float(done)) obs = new_obs @@ -220,6 +221,9 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ episode_rewards[-1] += rew if done: + maybe_is_success = info.get('is_success') + if maybe_is_success is not None: + episode_successes.append(float(maybe_is_success)) if not isinstance(self.env, VecEnv): obs = self.env.reset() episode_rewards.append(0.0) @@ -271,6 +275,8 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ if self.verbose >= 1 and done and log_interval is not None and len(episode_rewards) % log_interval == 0: logger.record_tabular("steps", self.num_timesteps) logger.record_tabular("episodes", num_episodes) + if len(episode_successes) > 0: + logger.logkv("success rate", np.mean(episode_successes[-100:])) logger.record_tabular("mean 100 episode reward", mean_100ep_reward) logger.record_tabular("% time spent exploring", int(100 * self.exploration.value(self.num_timesteps))) diff --git a/stable_baselines/sac/sac.py b/stable_baselines/sac/sac.py index 41c87c0b5d..77b6888a05 100644 --- a/stable_baselines/sac/sac.py +++ b/stable_baselines/sac/sac.py @@ -375,6 +375,7 @@ def learn(self, total_timesteps, callback=None, seed=None, start_time = time.time() episode_rewards = [0.0] + episode_successes = [] if self.action_noise is not None: self.action_noise.reset() obs = self.env.reset() @@ -454,6 +455,10 @@ def learn(self, total_timesteps, callback=None, seed=None, obs = self.env.reset() episode_rewards.append(0.0) + maybe_is_success = info.get('is_success') + if maybe_is_success is not None: + episode_successes.append(float(maybe_is_success)) + if len(episode_rewards[-101:-1]) == 0: mean_reward = -np.inf else: @@ -473,6 +478,8 @@ def learn(self, total_timesteps, callback=None, seed=None, logger.logkv("current_lr", current_lr) logger.logkv("fps", fps) logger.logkv('time_elapsed', int(time.time() - start_time)) + if len(episode_successes) > 0: + logger.logkv("success rate", np.mean(episode_successes[-100:])) if len(infos_values) > 0: for (name, val) in zip(self.infos_names, infos_values): logger.logkv(name, val) From 8023bbce9bc6bbc194130cb5b596d014e3e69dd7 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 23 Apr 2019 20:48:55 +0200 Subject: [PATCH 11/36] Fix HER learning method --- stable_baselines/her/her.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/stable_baselines/her/her.py b/stable_baselines/her/her.py index 094b22ff2a..32c98e4328 100644 --- a/stable_baselines/her/her.py +++ b/stable_baselines/her/her.py @@ -102,8 +102,9 @@ def setup_model(self): def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="HER", reset_num_timesteps=True): - return self.model.learn(total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="HER", - reset_num_timesteps=True, replay_wrapper=self.replay_wrapper) + return self.model.learn(total_timesteps, callback=callback, seed=seed, log_interval=log_interval, + tb_log_name=tb_log_name, reset_num_timesteps=reset_num_timesteps, + replay_wrapper=self.replay_wrapper) def _check_obs(self, observation): if isinstance(observation, dict): From 09e514d844dbd84261a70685600686c6360a9db2 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 27 Apr 2019 23:32:02 +0200 Subject: [PATCH 12/36] Add support for VecEnv + improve comments + add properties to ReplayBuffer --- docs/modules/her.rst | 47 +----------- stable_baselines/deepq/replay_buffer.py | 22 +++++- stable_baselines/her/__init__.py | 1 + stable_baselines/her/her.py | 23 +++--- stable_baselines/her/replay_buffer.py | 14 +--- stable_baselines/her/utils.py | 17 +++-- tests/test_her.py | 99 +++++++++++-------------- 7 files changed, 95 insertions(+), 128 deletions(-) diff --git a/docs/modules/her.rst b/docs/modules/her.rst index bbbdcc9552..c77ac35733 100644 --- a/docs/modules/her.rst +++ b/docs/modules/her.rst @@ -8,56 +8,15 @@ HER `Hindsight Experience Replay (HER) `_ -.. warning:: +.. note:: - HER is not refactored yet. We are looking for contributors to help us. + HER was re-implemented from scratch in Stable-Baselines compared to the original OpenAI baselines -How to use Hindsight Experience Replay --------------------------------------- -Getting started -~~~~~~~~~~~~~~~ +`Plappert et al. (2018)`_ -Training an agent is very simple: - -.. code:: bash - - python -m stable_baselines.her.experiment.train - -This will train a DDPG+HER agent on the ``FetchReach`` environment. You -should see the success rate go up quickly to ``1.0``, which means that -the agent achieves the desired goal in 100% of the cases. The training -script logs other diagnostics as well and pickles the best policy so far -(w.r.t. to its test success rate), the latest policy, and, if enabled, a -history of policies every K epochs. - -To inspect what the agent has learned, use the play script: - -.. code:: bash - - python -m stable_baselines.her.experiment.play /path/to/an/experiment/policy_best.pkl - -You can try it right now with the results of the training step (the -script prints out the path for you). This should visualize the current -policy for 10 episodes and will also print statistics. - -Reproducing results -~~~~~~~~~~~~~~~~~~~ - -In order to reproduce the results from `Plappert et al. (2018)`_, run -the following command: - -.. code:: bash - - python -m stable_baselines.her.experiment.train --num_cpu 19 - -This will require a machine with sufficient amount of physical CPU -cores. In our experiments, we used `Azure's D15v2 instances`_, which -have 20 physical cores. We only scheduled the experiment on 19 of those -to leave some head-room on the system. .. _Plappert et al. (2018): https://arxiv.org/abs/1802.09464 -.. _Azure's D15v2 instances: https://docs.microsoft.com/en-us/azure/virtual-machines/linux/sizes Parameters diff --git a/stable_baselines/deepq/replay_buffer.py b/stable_baselines/deepq/replay_buffer.py index 3cb2c0d3a2..d6ea165a32 100644 --- a/stable_baselines/deepq/replay_buffer.py +++ b/stable_baselines/deepq/replay_buffer.py @@ -8,7 +8,7 @@ class ReplayBuffer(object): def __init__(self, size): """ - Create Replay buffer. + Implements a ring buffer (FIFO). :param size: (int) Max number of transitions to store in the buffer. When the buffer overflows the old memories are dropped. @@ -20,6 +20,24 @@ def __init__(self, size): def __len__(self): return len(self._storage) + @property + def storage(self): + """[(np.ndarray, float, float, np.ndarray, bool)]: content of the replay buffer""" + return self._storage + + @property + def buffer_size(self): + """float: Max capacity of the buffer""" + return self._maxsize + + def is_full(self): + """ + Check whether the replay buffer is full or not. + + :return: (bool) + """ + return len(self) == self.buffer_size + def add(self, obs_t, action, reward, obs_tp1, done): """ add a new transition to the buffer @@ -63,6 +81,8 @@ def sample(self, batch_size, **_kwargs): - done_mask: (numpy bool) done_mask[i] = 1 if executing act_batch[i] resulted in the end of an episode and 0 otherwise. """ + # TODO(araffin): should we ensure sample with no replacement? + # using np.random.choice() idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)] return self._encode_sample(idxes) diff --git a/stable_baselines/her/__init__.py b/stable_baselines/her/__init__.py index 0c7dd57c64..c6f148b711 100644 --- a/stable_baselines/her/__init__.py +++ b/stable_baselines/her/__init__.py @@ -1,2 +1,3 @@ from stable_baselines.her.her import HER from stable_baselines.her.replay_buffer import GoalSelectionStrategy +from stable_baselines.her.utils import HERGoalEnvWrapper diff --git a/stable_baselines/her/her.py b/stable_baselines/her/her.py index 32c98e4328..87698171b5 100644 --- a/stable_baselines/her/her.py +++ b/stable_baselines/her/her.py @@ -29,6 +29,10 @@ def __init__(self, policy, env, model_class, n_sampled_goal=4, self.model_class = model_class self.replay_wrapper = None + # Save dict observation space (used for checks at loading time) + if env is not None: + self.observation_space = env.observation_space + self.action_space = env.action_space # Convert string to GoalSelectionStrategy object if isinstance(goal_selection_strategy, str): @@ -49,16 +53,12 @@ def __init__(self, policy, env, model_class, n_sampled_goal=4, def _create_replay_wrapper(self, env): - # if isinstance(env, VecEnv): - # assert isinstance(env, _UnvecWrapper) - - # TODO: check if the env is not already wrapped if not isinstance(env, HERGoalEnvWrapper): env = HERGoalEnvWrapper(env) self.env = env - # TODO: check for TimeLimit wrapper too - # TODO: support VecEnv + # NOTE: we cannot do that check directly with VecEnv + # maybe we can try calling `compute_reward()` ? # assert isinstance(self.env, gym.GoalEnv), "HER only supports gym.GoalEnv" self.replay_wrapper = functools.partial(HindsightExperienceReplayWrapper, @@ -67,11 +67,8 @@ def _create_replay_wrapper(self, env): wrapped_env=self.env) def set_env(self, env): - # Unwrap VecEnv if needed - # TODO: save/load correct observation_space - # which is different between HER and the wrapped env - # super().set_env(env) - self._create_replay_wrapper(env) + super().set_env(env) + self._create_replay_wrapper(self.env) self.model.set_env(self.env) def get_env(self): @@ -130,6 +127,8 @@ def _save_to_file(self, save_path, data=None, params=None): data['n_sampled_goal'] = self.n_sampled_goal data['goal_selection_strategy'] = self.goal_selection_strategy data['model_class'] = self.model_class + data['her_obs_space'] = self.observation_space + data['her_action_space'] = self.action_space super()._save_to_file(save_path, data, params) def save(self, save_path): @@ -149,6 +148,8 @@ def load(cls, load_path, env=None, **kwargs): n_sampled_goal=data['n_sampled_goal'], goal_selection_strategy=data['goal_selection_strategy'], _init_setup_model=False) + model.__dict__['observation_space'] = data['her_obs_space'] + model.__dict__['action_space'] = data['her_action_space'] # model.__dict__.update(data) # model.__dict__.update(kwargs) model.model = data['model_class'].load(load_path, model.get_env(), **kwargs) diff --git a/stable_baselines/her/replay_buffer.py b/stable_baselines/her/replay_buffer.py index d603b9eb2a..8729e4c5b2 100644 --- a/stable_baselines/her/replay_buffer.py +++ b/stable_baselines/her/replay_buffer.py @@ -60,9 +60,6 @@ def __init__(self, replay_buffer, n_sampled_goal, goal_selection_strategy, wrapp self.episode_transitions = [] self.replay_buffer = replay_buffer - def append(self, obs_t, action, reward, obs_tp1, done): - return self.add(obs_t, action, reward, obs_tp1, done) - def add(self, obs_t, action, reward, obs_tp1, done): """ add a new transition to the buffer @@ -104,13 +101,13 @@ def _sample_achieved_goal(self, episode_transitions, transition_idx): # Choose the goal achieved at the end of the episode selected_transition = episode_transitions[-1] elif self.goal_selection_strategy == GoalSelectionStrategy.EPISODE: - # Random achieved goal during the episode + # Random goal achieved during the episode selected_idx = np.random.choice(np.arange(len(episode_transitions))) selected_transition = episode_transitions[selected_idx] elif self.goal_selection_strategy == GoalSelectionStrategy.RANDOM: - # Random achieved goal from the entire replay buffer + # Random goal achieved, from the entire replay buffer selected_idx = np.random.choice(np.arange(len(self.replay_buffer))) - selected_transition = self.replay_buffer._storage[selected_idx] + selected_transition = self.replay_buffer.storage[selected_idx] else: raise ValueError("Invalid goal selection strategy," "please use one of {}".format(list(GoalSelectionStrategy))) @@ -135,9 +132,6 @@ def _store_episode(self): episode in the replay buffer. This method is called only after each end of episode. """ - # NOTE: is deepcopy really needed here? - # last_episode_transitions = copy.deepcopy(self.episode_transitions) - # For each transition in the last episode, # create a set of artificial transitions for transition_idx, transition in enumerate(self.episode_transitions): @@ -168,7 +162,7 @@ def _store_episode(self): # Update the reward according to the new desired goal reward = self.env.compute_reward(goal, next_obs_dict['achieved_goal'], None) - # Can we ensure that done = reward == 0 + # Can we ensure that done = reward == 0? done = False # Transform back to ndarrays diff --git a/stable_baselines/her/utils.py b/stable_baselines/her/utils.py index 38b1712123..00524e7abd 100644 --- a/stable_baselines/her/utils.py +++ b/stable_baselines/her/utils.py @@ -4,8 +4,9 @@ class HERGoalEnvWrapper(object): """ - A wrapper that allow to use dict env (coming from GoalEnv) with + A wrapper that allow to use dict observation space (coming from GoalEnv) with the RL algorithms. + It assumes that all the spaces of the dict space are of the same type. :param env: (gym.GoalEnv) """ @@ -17,7 +18,7 @@ def __init__(self, env): self.action_space = env.action_space self.spaces = list(env.observation_space.spaces.values()) # TODO: check that all spaces are of the same type - # (current limiation of the wrapper) + # (current limitation of the wrapper) # TODO: check when dim > 1 goal_space_shape = env.observation_space.spaces['achieved_goal'].shape @@ -26,18 +27,21 @@ def __init__(self, env): total_dim = self.obs_dim + 2 * self.goal_dim if len(goal_space_shape) == 2: - assert goal_space_shape[1] == 1 + assert goal_space_shape[1] == 1, "Only 1D observation spaces are supported yet" else: - assert len(goal_space_shape) == 1 + assert len(goal_space_shape) == 1, "Only 1D observation spaces are supported yet" if isinstance(self.spaces[0], spaces.MultiBinary): self.observation_space = spaces.MultiBinary(total_dim) + elif isinstance(self.spaces[0], spaces.Box): lows = np.concatenate([space.low for space in self.spaces]) highs = np.concatenate([space.high for space in self.spaces]) self.observation_space = spaces.Box(lows, highs, dtype=np.float32) + elif isinstance(self.spaces[0], spaces.Discrete): pass + else: raise NotImplementedError() @@ -47,11 +51,14 @@ def convert_dict_to_obs(obs_dict): :param obs_dict: (dict) :return: (np.ndarray) """ - # Note: we should remove achieved goal from the observation ? + # Note: achieved goal is not removed from the observation + # this is helpful to have a revertible transformation return np.concatenate([obs for obs in obs_dict.values()]) def convert_obs_to_dict(self, observations): """ + Inverse operation of convert_dict_to_obs + :param observations: (np.ndarray) :return: (dict) """ diff --git a/tests/test_her.py b/tests/test_her.py index 139e11430f..2f329f0e60 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -3,14 +3,34 @@ import pytest from stable_baselines import HER, DQN, SAC, DDPG -from stable_baselines.her import GoalSelectionStrategy +from stable_baselines.her import GoalSelectionStrategy, HERGoalEnvWrapper from stable_baselines.her.replay_buffer import KEY_TO_GOAL_STRATEGY from stable_baselines.common.bit_flipping_env import BitFlippingEnv -from stable_baselines.common.vec_env import DummyVecEnv +from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize N_BITS = 10 +def model_predict(model, env, n_steps, additional_check=None): + """ + Test helper + :param model: (rl model) + :param env: (gym.Env) + :param n_steps: (int) + :param additional_check: (callable) + """ + obs = env.reset() + for _ in range(n_steps): + action, _ = model.predict(obs) + obs, reward, done, _ = env.step(action) + + if additional_check is not None: + additional_check(obs, action, reward, done) + + if done: + obs = env.reset() + + @pytest.mark.parametrize('goal_selection_strategy', list(GoalSelectionStrategy)) @pytest.mark.parametrize('model_class', [DQN, SAC, DDPG]) def test_her(model_class, goal_selection_strategy): @@ -25,82 +45,47 @@ def test_her(model_class, goal_selection_strategy): @pytest.mark.parametrize('model_class', [DQN, SAC, DDPG]) def test_model_manipulation(model_class, goal_selection_strategy): env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC], max_steps=N_BITS) + env = DummyVecEnv([lambda: env]) + # NOTE: HER does not support VecEnvWrapper yet + # env = VecNormalize(env) model = HER('MlpPolicy', env, model_class, n_sampled_goal=3, goal_selection_strategy=goal_selection_strategy, verbose=0) model.learn(1000) - model.save('./test_her') - - obs = env.reset() - for _ in range(100): - action, _ = model.predict(obs) - obs, _, done, _ = env.step(action) - if done: - obs = env.reset() + model_predict(model, env, n_steps=100, additional_check=None) + model.save('./test_her') del model model = HER.load('./test_her') + + # Check that the model raises an error when the env + # is not wrapped (or no env passed to the model) + with pytest.raises(ValueError): + model.predict(env.reset()) + + env_ = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC], max_steps=N_BITS) + env_ = HERGoalEnvWrapper(env_) + + model_predict(model, env_, n_steps=100, additional_check=None) + model.set_env(env) model.learn(1000) - env = model.get_env() - obs = env.reset() - for _ in range(100): - action, _ = model.predict(obs) - obs, _, done, _ = env.step(action) - if done: - obs = env.reset() + model_predict(model, env_, n_steps=100, additional_check=None) assert model.n_sampled_goal == 3 del model + env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC], max_steps=N_BITS) model = HER.load('./test_her', env=env) model.learn(1000) + model_predict(model, env_, n_steps=100, additional_check=None) + assert model.n_sampled_goal == 3 if os.path.isfile('./test_her.pkl'): os.remove('./test_her.pkl') - - -@pytest.mark.parametrize('goal_selection_strategy', [list(KEY_TO_GOAL_STRATEGY.keys())[0]]) -@pytest.mark.parametrize('model_class', [DQN, SAC, DDPG]) -def test_her_vec_env(model_class, goal_selection_strategy): - env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC], max_steps=N_BITS) - # Test vec env - env = DummyVecEnv([lambda: env]) - - model = HER('MlpPolicy', env, model_class, n_sampled_goal=3, goal_selection_strategy=goal_selection_strategy, - verbose=0) - model.learn(1000) - - obs = env.reset() - for _ in range(100): - action, _ = model.predict(obs) - obs, _, done, _ = env.step(action) - if done: - obs = env.reset() - -# model.save('./test_her') -# del model -# -# # TODO: wrap into an _UnvecWrapper if needed -# # model = HER.load('./test_her') -# # model.set_env(env) -# # model.learn(1000) -# # -# # assert model.n_sampled_goal == 3 -# # -# # del model -# -# env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC], max_steps=N_BITS) -# model = HER.load('./test_her', env=env) -# model.learn(1000) -# -# assert model.n_sampled_goal == 3 -# -# if os.path.isfile('./test_her.pkl'): -# os.remove('./test_her.pkl') From c6479e44ff673a48b7c8f76c6eaaf52414eea3a0 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 28 Apr 2019 00:12:52 +0200 Subject: [PATCH 13/36] Update documentation --- README.md | 30 ++++++------ docs/guide/algos.rst | 23 ++++------ docs/guide/examples.rst | 5 ++ docs/modules/her.rst | 78 +++++++++++++++++++++++++++++++- stable_baselines/ddpg/ddpg.py | 4 ++ stable_baselines/her/__init__.py | 2 +- stable_baselines/her/her.py | 11 ++--- tests/test_her.py | 2 +- 8 files changed, 116 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index 0f2d3307ae..c1f3a4a4c6 100644 --- a/README.md +++ b/README.md @@ -28,14 +28,14 @@ This toolset is a fork of OpenAI Baselines, with a major structural refactoring, | Common interface | :heavy_check_mark: | :heavy_minus_sign: (3) | | Tensorboard support | :heavy_check_mark: | :heavy_minus_sign: (4) | | Ipython / Notebook friendly | :heavy_check_mark: | :x: | -| PEP8 code style | :heavy_check_mark: | :heavy_minus_sign: (5) | +| PEP8 code style | :heavy_check_mark: | :heavy_check_mark: (5) | | Custom callback | :heavy_check_mark: | :heavy_minus_sign: (6) | -(1): Forked from previous version of OpenAI baselines, however missing refactoring for HER.
+(1): Forked from previous version of OpenAI baselines, with now SAC in addition
(2): Currently not available for DDPG, and only from the run script.
(3): Only via the run script.
(4): Rudimentary logging of training information (no loss nor graph).
-(5): WIP on OpenAI's side (you can do it OpenAI! :cat:)
+(5): EDIT: you did it OpenAI! :cat:
(6): Passing a callback function is only available for DQN
## Documentation @@ -143,13 +143,13 @@ All the following examples can be executed online using Google colab notebooks: | **Name** | **Refactored**(1) | **Recurrent** | ```Box``` | ```Discrete``` | ```MultiDiscrete``` | ```MultiBinary``` | **Multi Processing** | | ------------------- | ---------------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- | -| A2C | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | +| A2C | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | ACER | :heavy_check_mark: | :heavy_check_mark: | :x: (5) | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | | ACKTR | :heavy_check_mark: | :heavy_check_mark: | :x: (5) | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | | DDPG | :heavy_check_mark: | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: | | DQN | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | :x: | :x: | :x: | | GAIL (2) | :heavy_check_mark: | :x: | :heavy_check_mark: |:heavy_check_mark:| :x: | :x: | :heavy_check_mark: (4) | -| HER (3) | :x: (5) | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: | +| HER (3) | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :heavy_check_mark:| :x: | | PPO1 | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: (4) | | PPO2 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | SAC | :heavy_check_mark: | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: | @@ -157,11 +157,11 @@ All the following examples can be executed online using Google colab notebooks: (1): Whether or not the algorithm has be refactored to fit the ```BaseRLModel``` class.
(2): Only implemented for TRPO.
-(3): Only implemented for DDPG.
+(3): Re-implemented from scratch
(4): Multi Processing with [MPI](https://mpi4py.readthedocs.io/en/stable/).
(5): TODO, in project scope. -NOTE: Soft Actor-Critic (SAC) was not part of the original baselines. +NOTE: Soft Actor-Critic (SAC) was not part of the original baselines and HER was reimplemented from scratch. Actions ```gym.spaces```: * ```Box```: A N-dimensional box that containes every point in the action space. @@ -190,14 +190,14 @@ please tell us when if you want your project to appear on this page ;) To cite this repository in publications: ``` - @misc{stable-baselines, - author = {Hill, Ashley and Raffin, Antonin and Ernestus, Maximilian and Gleave, Adam and Traore, Rene and Dhariwal, Prafulla and Hesse, Christopher and Klimov, Oleg and Nichol, Alex and Plappert, Matthias and Radford, Alec and Schulman, John and Sidor, Szymon and Wu, Yuhuai}, - title = {Stable Baselines}, - year = {2018}, - publisher = {GitHub}, - journal = {GitHub repository}, - howpublished = {\url{https://github.com/hill-a/stable-baselines}}, - } +@misc{stable-baselines, + author = {Hill, Ashley and Raffin, Antonin and Ernestus, Maximilian and Gleave, Adam and Traore, Rene and Dhariwal, Prafulla and Hesse, Christopher and Klimov, Oleg and Nichol, Alex and Plappert, Matthias and Radford, Alec and Schulman, John and Sidor, Szymon and Wu, Yuhuai}, + title = {Stable Baselines}, + year = {2018}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/hill-a/stable-baselines}}, +} ``` ## Maintainers diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index a39852edd4..f2967cdfe7 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -11,34 +11,31 @@ along with some useful characteristics: support for recurrent policies, discrete .. A2C ✔️ .. ===== ======================== ========= ======= ============ ================= =============== ================ -.. There is an issue with Read The Docs for building the table when the "HER" row is present: -.. Apparently a problem of spacing -.. HER [#f3]_ ❌ [#f5]_ ❌ ✔️ ❌ ❌ - ============ ======================== ========= =========== ============ ================ Name Refactored [#f1]_ Recurrent ``Box`` ``Discrete`` Multi Processing ============ ======================== ========= =========== ============ ================ A2C ✔️ ✔️ ✔️ ✔️ ✔️ -ACER ✔️ ✔️ ❌ [#f5]_ ✔️ ✔️ -ACKTR ✔️ ✔️ ❌ [#f5]_ ✔️ ✔️ +ACER ✔️ ✔️ ❌ [#f4]_ ✔️ ✔️ +ACKTR ✔️ ✔️ ❌ [#f4]_ ✔️ ✔️ DDPG ✔️ ❌ ✔️ ❌ ❌ DQN ✔️ ❌ ❌ ✔️ ❌ -GAIL [#f2]_ ✔️ ✔️ ✔️ ✔️ ✔️ [#f4]_ -PPO1 ✔️ ❌ ✔️ ✔️ ✔️ [#f4]_ +HER ✔️ ❌ ✔️ ✔️ ❌ +GAIL [#f2]_ ✔️ ✔️ ✔️ ✔️ ✔️ [#f3]_ +PPO1 ✔️ ❌ ✔️ ✔️ ✔️ [#f3]_ PPO2 ✔️ ✔️ ✔️ ✔️ ✔️ SAC ✔️ ❌ ✔️ ❌ ❌ -TRPO ✔️ ❌ ✔️ ✔️ ✔️ [#f4]_ +TRPO ✔️ ❌ ✔️ ✔️ ✔️ [#f3]_ ============ ======================== ========= =========== ============ ================ .. [#f1] Whether or not the algorithm has be refactored to fit the ``BaseRLModel`` class. .. [#f2] Only implemented for TRPO. -.. [#f3] Only implemented for DDPG. -.. [#f4] Multi Processing with `MPI`_. -.. [#f5] TODO, in project scope. +.. [#f3] Multi Processing with `MPI`_. +.. [#f4] TODO, in project scope. .. note:: - Non-array spaces such as `Dict` or `Tuple` are not currently supported by any algorithm. + Non-array spaces such as `Dict` or `Tuple` are not currently supported by any algorithm, + except HER for dict when working with gym.GoalEnv Actions ``gym.spaces``: diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index df4a81cdef..05da2186ae 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -349,6 +349,11 @@ This example demonstrate how to train a recurrent policy and how to test it prop env.render() +Hindsight Experience Replay (HER) +--------------------------------- + +TODO: highway env + Continual Learning ------------------ diff --git a/docs/modules/her.rst b/docs/modules/her.rst index c77ac35733..fa237e806b 100644 --- a/docs/modules/her.rst +++ b/docs/modules/her.rst @@ -8,19 +8,95 @@ HER `Hindsight Experience Replay (HER) `_ +HER is a method wrapper that works with Off policy methods (DQN, SAC and DDPG for example). + .. note:: HER was re-implemented from scratch in Stable-Baselines compared to the original OpenAI baselines +.. warning:: + + HER requires the environment to inherits from `gym.GoalEnv `_ -`Plappert et al. (2018)`_ +Notes +----- + +- Original paper: https://arxiv.org/abs/1707.01495 +- OpenAI paper: `Plappert et al. (2018)`_ +- OpenAI blog post: https://openai.com/blog/ingredients-for-robotics-research/ .. _Plappert et al. (2018): https://arxiv.org/abs/1802.09464 +Can I use? +---------- + +Please refer to the wrapped model (DQN, SAC or DDPG) for that section. + +Example +------- + +.. code-block:: python + + from stable_baselines import HER, DQN, SAC, DDPG + from stable_baselines.her import GoalSelectionStrategy, HERGoalEnvWrapper + from stable_baselines.common.bit_flipping_env import BitFlippingEnv + + model_class = DQN # works also with SAC and DDPG + + env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC], max_steps=N_BITS) + + # Available strategies (cf paper): future, final, episode, random + goal_selection_strategy = 'future' # equivalent to GoalSelectionStrategy.FUTURE + + # Wrap the model + model = HER('MlpPolicy', env, model_class, n_sampled_goal=4, goal_selection_strategy=goal_selection_strategy, + verbose=1) + # Train the model + model.learn(1000) + + model.save("./her_bit_env") + + # WARNING: you must pass an env + # or wrap your environment with HERGoalEnvWrapper to use the predict method + model = HER.load('./her_bit_env', env=env) + + obs = env.reset() + for _ in range(100): + action, _ = model.predict(obs) + obs, reward, done, _ = env.step(action) + + if done: + obs = env.reset() + Parameters ---------- .. autoclass:: HER :members: + +Goal Selection Strategies +------------------------- + +.. autoclass:: GoalSelectionStrategy + :members: + :inherited-members: + :undoc-members: + + +Gaol Env Wrapper +---------------- + +.. autoclass:: HERGoalEnvWrapper + :members: + :inherited-members: + :undoc-members: + + +Replay Wrapper +-------------- + +.. autoclass:: HindsightExperienceReplayWrapper + :members: + :inherited-members: diff --git a/stable_baselines/ddpg/ddpg.py b/stable_baselines/ddpg/ddpg.py index e2377fa937..97ee031130 100644 --- a/stable_baselines/ddpg/ddpg.py +++ b/stable_baselines/ddpg/ddpg.py @@ -142,8 +142,10 @@ class DDPG(OffPolicyRLModel): :param gamma: (float) the discount factor :param memory_policy: (ReplayBuffer) the replay buffer (if None, default to baselines.deepq.replay_buffer.ReplayBuffer) + .. deprecated:: 2.6.0 This parameter will be removed in a future version + :param eval_env: (Gym Environment) the evaluation environment (can be None) :param nb_train_steps: (int) the number of training steps :param nb_rollout_steps: (int) the number of rollout steps @@ -167,8 +169,10 @@ class DDPG(OffPolicyRLModel): :param render: (bool) enable rendering of the environment :param render_eval: (bool) enable rendering of the evalution environment :param memory_limit: (int) the max number of transitions to store, size of the replay buffer + .. deprecated:: 2.6.0 Use `buffer_size` instead. + :param buffer_size: (int) the max number of transitions to store, size of the replay buffer :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) diff --git a/stable_baselines/her/__init__.py b/stable_baselines/her/__init__.py index c6f148b711..6b47d6e794 100644 --- a/stable_baselines/her/__init__.py +++ b/stable_baselines/her/__init__.py @@ -1,3 +1,3 @@ from stable_baselines.her.her import HER -from stable_baselines.her.replay_buffer import GoalSelectionStrategy +from stable_baselines.her.replay_buffer import GoalSelectionStrategy, HindsightExperienceReplayWrapper from stable_baselines.her.utils import HERGoalEnvWrapper diff --git a/stable_baselines/her/her.py b/stable_baselines/her/her.py index 87698171b5..dace5a0497 100644 --- a/stable_baselines/her/her.py +++ b/stable_baselines/her/her.py @@ -1,7 +1,5 @@ import functools -import gym - from stable_baselines.common import BaseRLModel from stable_baselines.common import OffPolicyRLModel from stable_baselines.common.base_class import _UnvecWrapper @@ -45,13 +43,13 @@ def __init__(self, policy, env, model_class, n_sampled_goal=4, if self.env is not None: self._create_replay_wrapper(self.env) - assert issubclass(model_class, OffPolicyRLModel),\ + assert issubclass(model_class, OffPolicyRLModel), \ "Error: HER only works with Off policy model (such as DDPG, SAC and DQN)." self.model = self.model_class(policy, self.env, *args, **kwargs) + # Patch to support saving/loading self.model._save_to_file = self._save_to_file - def _create_replay_wrapper(self, env): if not isinstance(env, HERGoalEnvWrapper): env = HERGoalEnvWrapper(env) @@ -89,7 +87,7 @@ def __set_attr__(self, attr, value): if attr in self.__dict__: setattr(self, attr, value) else: - set_attr(self.model, attr, value) + setattr(self.model, attr, value) def _get_pretrain_placeholders(self): return self.model._get_pretrain_placeholders() @@ -132,7 +130,6 @@ def _save_to_file(self, save_path, data=None, params=None): super()._save_to_file(save_path, data, params) def save(self, save_path): - # Is there something more to save? (the replay wrapper?) self.model.save(save_path) @classmethod @@ -150,8 +147,6 @@ def load(cls, load_path, env=None, **kwargs): _init_setup_model=False) model.__dict__['observation_space'] = data['her_obs_space'] model.__dict__['action_space'] = data['her_action_space'] - # model.__dict__.update(data) - # model.__dict__.update(kwargs) model.model = data['model_class'].load(load_path, model.get_env(), **kwargs) model.model._save_to_file = model._save_to_file return model diff --git a/tests/test_her.py b/tests/test_her.py index 2f329f0e60..8ff8166f25 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -6,7 +6,7 @@ from stable_baselines.her import GoalSelectionStrategy, HERGoalEnvWrapper from stable_baselines.her.replay_buffer import KEY_TO_GOAL_STRATEGY from stable_baselines.common.bit_flipping_env import BitFlippingEnv -from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize +from stable_baselines.common.vec_env import DummyVecEnv # , VecNormalize N_BITS = 10 From c72e760afe7b920e6b69962d58a7d06ba5850a62 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 28 Apr 2019 15:01:02 +0200 Subject: [PATCH 14/36] Add HER example --- docs/guide/examples.rst | 69 ++++++++++++++++++++++++++- docs/modules/her.rst | 6 +++ stable_baselines/common/base_class.py | 6 ++- stable_baselines/her/her.py | 4 ++ 4 files changed, 83 insertions(+), 2 deletions(-) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 05da2186ae..8fa10a7986 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -352,7 +352,74 @@ This example demonstrate how to train a recurrent policy and how to test it prop Hindsight Experience Replay (HER) --------------------------------- -TODO: highway env +For this example, we are using `Highway-Env `_ by `@eleurent `_. + +.. figure:: https://raw.githubusercontent.com/eleurent/highway-env/gh-media/docs/media/parking-env.gif + + The highway-parking-v0 environment. + +The parking env is a goal-conditioned continuous control task, in which the vehicle must park in a given space with the appropriate heading. + +.. note:: + + the hyperparameters in the following example were optimized for that environment. + + +.. code-block:: python + + import gym + import highway_env + import numpy as np + + from stable_baselines import HER, SAC, DDPG + from stable_baselines.ddpg import NormalActionNoise + + env = gym.make("highway-parking-v0") + + # Create 4 artificial transition per real transition + n_sampled_goal = 4 + + # SAC hyperparams: + model = HER('MlpPolicy', env, SAC, n_sampled_goal=n_sampled_goal, + goal_selection_strategy='future', + verbose=1, buffer_size=int(1e6), + learning_rate=1e-3, + gamma=0.95, batch_size=256, + policy_kwargs=dict(layers=[256, 256, 256])) + + # DDPG Hyperparams: + # NOTE: it works even without action noise + # n_actions = env.action_space.shape[0] + # noise_std = 0.2 + # action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=noise_std * np.ones(n_actions)) + # model = HER('MlpPolicy', env, DDPG, n_sampled_goal=n_sampled_goal, + # goal_selection_strategy='future', + # verbose=1, buffer_size=int(1e6), + # actor_lr=1e-3, critic_lr=1e-3, action_noise=action_noise, + # gamma=0.95, batch_size=256, + # policy_kwargs=dict(layers=[256, 256, 256])) + + + model.learn(int(2e5)) + model.save('her_sac_highway') + + # Load saved model + model = HER.load('her_sac_highway', env=env) + + obs = env.reset() + + # Evaluate the agent + episode_reward = 0 + for _ in range(100): + action, _ = model.predict(obs) + obs, reward, done, info = env.step(action) + env.render() + episode_reward += reward + if done or info.get('is_success', False): + print("Reward:", episode_reward, "Success?", info.get('is_success', False)) + episode_reward = 0.0 + obs = env.reset() + Continual Learning diff --git a/docs/modules/her.rst b/docs/modules/her.rst index fa237e806b..505d2af87f 100644 --- a/docs/modules/her.rst +++ b/docs/modules/her.rst @@ -18,6 +18,12 @@ HER is a method wrapper that works with Off policy methods (DQN, SAC and DDPG fo HER requires the environment to inherits from `gym.GoalEnv `_ + +.. warning:: + + you must pass an environment or wrap it with ``HERGoalEnvWrapper`` in order to use the predict method + + Notes ----- diff --git a/stable_baselines/common/base_class.py b/stable_baselines/common/base_class.py index 755b20f173..0da7631e25 100644 --- a/stable_baselines/common/base_class.py +++ b/stable_baselines/common/base_class.py @@ -634,13 +634,17 @@ def __set_attr__(self, attr, value): if attr in self.__dict__: setattr(self, attr, value) else: - set_attr(self.venv, attr, value) + setattr(self.venv, attr, value) def compute_reward(self, achieved_goal, desired_goal, _info): return float(self.venv.env_method('compute_reward', achieved_goal, desired_goal, _info)[0]) @staticmethod def unvec_obs(obs): + """ + :param obs: (Union[np.ndarray, dict]) + :return: (Union[np.ndarray, dict]) + """ if not isinstance(obs, dict): return obs[0] obs_ = OrderedDict() diff --git a/stable_baselines/her/her.py b/stable_baselines/her/her.py index dace5a0497..dc1d3cdb34 100644 --- a/stable_baselines/her/her.py +++ b/stable_baselines/her/her.py @@ -51,6 +51,10 @@ def __init__(self, policy, env, model_class, n_sampled_goal=4, self.model._save_to_file = self._save_to_file def _create_replay_wrapper(self, env): + """ + Wrap the environment in a HERGoalEnvWrapper + if needed and create the replay buffer wrapper. + """ if not isinstance(env, HERGoalEnvWrapper): env = HERGoalEnvWrapper(env) From 88cb4e5cdcf846c45c48f43de92a12f98f109193 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 4 May 2019 11:20:46 +0200 Subject: [PATCH 15/36] Removed unused dependencies (tdqm, dill, progressbar2, seaborn, glob2, click) --- docs/misc/changelog.rst | 1 + setup.py | 10 ++-------- stable_baselines/__init__.py | 2 +- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e79a134d12..cf00c66563 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -15,6 +15,7 @@ Pre-Release 2.6.0a0 (WIP) - **deprecated** ``memory_limit`` and ``memory_policy`` in DDPG, please use ``buffer_size`` instead. (will be removed in v3.x.x) - removed ``stable_baselines.ddpg.memory`` in favor of ``stable_baselines.deepq.replay_buffer`` - add ``action_noise`` param for SAC, it helps exploration for problem with deceptive reward +- removed unused dependencies (tdqm, dill, progressbar2, seaborn, glob2, click) Release 2.5.1 (2019-05-04) diff --git a/setup.py b/setup.py index 3e2bd445c4..dea2d56728 100644 --- a/setup.py +++ b/setup.py @@ -107,19 +107,13 @@ install_requires=[ 'gym[atari,classic_control]>=0.10.9', 'scipy', - 'tqdm', 'joblib', - 'zmq', - 'dill', 'mpi4py', 'cloudpickle>=0.5.5', - 'click', 'opencv-python', 'numpy', 'pandas', - 'matplotlib', - 'seaborn', - 'glob2' + 'matplotlib' ] + tf_dependency, extras_require={ 'tests': [ @@ -143,7 +137,7 @@ license="MIT", long_description=long_description, long_description_content_type='text/markdown', - version="2.5.1", + version="2.6.0a0", ) # python setup.py sdist diff --git a/stable_baselines/__init__.py b/stable_baselines/__init__.py index b48e951b4f..4f32d2ca7b 100644 --- a/stable_baselines/__init__.py +++ b/stable_baselines/__init__.py @@ -10,4 +10,4 @@ from stable_baselines.trpo_mpi import TRPO from stable_baselines.sac import SAC -__version__ = "2.5.1" +__version__ = "2.6.0a0" From 6c7f5bb8e13e91a215f52e46962a5ba4a6a17f5a Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 4 May 2019 11:23:01 +0200 Subject: [PATCH 16/36] Remove note on the replay buffer --- stable_baselines/deepq/replay_buffer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/stable_baselines/deepq/replay_buffer.py b/stable_baselines/deepq/replay_buffer.py index d6ea165a32..a8771a7570 100644 --- a/stable_baselines/deepq/replay_buffer.py +++ b/stable_baselines/deepq/replay_buffer.py @@ -81,8 +81,6 @@ def sample(self, batch_size, **_kwargs): - done_mask: (numpy bool) done_mask[i] = 1 if executing act_batch[i] resulted in the end of an episode and 0 otherwise. """ - # TODO(araffin): should we ensure sample with no replacement? - # using np.random.choice() idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)] return self._encode_sample(idxes) From 65d21e2bbed666dc60f2cef6bc45a538246f071b Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 5 May 2019 15:29:11 +0200 Subject: [PATCH 17/36] Update doc + add a check for VecEnvWrapper with HER --- docs/guide/examples.rst | 4 ++-- stable_baselines/her/her.py | 4 ++++ stable_baselines/her/replay_buffer.py | 4 ++-- tests/test_her.py | 8 +++++--- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 8fa10a7986..6a8adb71e5 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -374,9 +374,9 @@ The parking env is a goal-conditioned continuous control task, in which the vehi from stable_baselines import HER, SAC, DDPG from stable_baselines.ddpg import NormalActionNoise - env = gym.make("highway-parking-v0") + env = gym.make("parking-v0") - # Create 4 artificial transition per real transition + # Create 4 artificial transitions per real transition n_sampled_goal = 4 # SAC hyperparams: diff --git a/stable_baselines/her/her.py b/stable_baselines/her/her.py index dc1d3cdb34..23a775e2ac 100644 --- a/stable_baselines/her/her.py +++ b/stable_baselines/her/her.py @@ -3,6 +3,7 @@ from stable_baselines.common import BaseRLModel from stable_baselines.common import OffPolicyRLModel from stable_baselines.common.base_class import _UnvecWrapper +from stable_baselines.common.vec_env import VecEnvWrapper from .replay_buffer import HindsightExperienceReplayWrapper, KEY_TO_GOAL_STRATEGY from .utils import HERGoalEnvWrapper @@ -22,6 +23,8 @@ class HER(BaseRLModel): def __init__(self, policy, env, model_class, n_sampled_goal=4, goal_selection_strategy='future', *args, **kwargs): + assert not isinstance(env, VecEnvWrapper), "HER does not support VecEnvWrapper yet" + super().__init__(policy=policy, env=env, verbose=kwargs.get('verbose', 0), policy_base=None, requires_vec_env=False) @@ -69,6 +72,7 @@ def _create_replay_wrapper(self, env): wrapped_env=self.env) def set_env(self, env): + assert not isinstance(env, VecEnvWrapper), "HER does not support VecEnvWrapper yet" super().set_env(env) self._create_replay_wrapper(self.env) self.model.set_env(self.env) diff --git a/stable_baselines/her/replay_buffer.py b/stable_baselines/her/replay_buffer.py index 8729e4c5b2..1b9121ba0d 100644 --- a/stable_baselines/her/replay_buffer.py +++ b/stable_baselines/her/replay_buffer.py @@ -36,7 +36,7 @@ class GoalSelectionStrategy(Enum): class HindsightExperienceReplayWrapper(object): """ Wrapper around a replay buffer in order to use HER. - This implementation is close to the one found in https://github.com/NervanaSystems/coach/. + This implementation is inspired by to the one found in https://github.com/NervanaSystems/coach/. :param replay_buffer: (ReplayBuffer) :param n_sampled_goal: (int) The number of artificial transitions to generate for each actual transition @@ -162,7 +162,7 @@ def _store_episode(self): # Update the reward according to the new desired goal reward = self.env.compute_reward(goal, next_obs_dict['achieved_goal'], None) - # Can we ensure that done = reward == 0? + # Can we use achieved_goal == desired_goal? done = False # Transform back to ndarrays diff --git a/tests/test_her.py b/tests/test_her.py index 8ff8166f25..f145e7223a 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -6,7 +6,7 @@ from stable_baselines.her import GoalSelectionStrategy, HERGoalEnvWrapper from stable_baselines.her.replay_buffer import KEY_TO_GOAL_STRATEGY from stable_baselines.common.bit_flipping_env import BitFlippingEnv -from stable_baselines.common.vec_env import DummyVecEnv # , VecNormalize +from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize N_BITS = 10 @@ -46,8 +46,6 @@ def test_her(model_class, goal_selection_strategy): def test_model_manipulation(model_class, goal_selection_strategy): env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC], max_steps=N_BITS) env = DummyVecEnv([lambda: env]) - # NOTE: HER does not support VecEnvWrapper yet - # env = VecNormalize(env) model = HER('MlpPolicy', env, model_class, n_sampled_goal=3, goal_selection_strategy=goal_selection_strategy, verbose=0) @@ -58,6 +56,10 @@ def test_model_manipulation(model_class, goal_selection_strategy): model.save('./test_her') del model + # NOTE: HER does not support VecEnvWrapper yet + with pytest.raises(AssertionError): + model = HER.load('./test_her', env=VecNormalize(env)) + model = HER.load('./test_her') # Check that the model raises an error when the env From 872386997c48e69f0ffd502688ac8de02db71d4f Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 5 May 2019 20:19:32 +0200 Subject: [PATCH 18/36] Update examples + add notebook for HER --- docs/guide/examples.rst | 50 ++++++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 6a8adb71e5..625d7ee7c7 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -13,27 +13,29 @@ notebooks: - `Monitor Training and Plotting`_ - `Atari Games`_ - `Breakout`_ (trained agent included) +- `Hindsight Experience Replay`_ .. _Getting Started: https://colab.research.google.com/drive/1_1H5bjWKYBVKbbs-Kj83dsfuZieDNcFU -.. _Training, Saving, Loading: https://colab.research.google.com/drive/1KoAQ1C_BNtGV3sVvZCnNZaER9rstmy0s +.. _Training, Saving, Loading: https://colab.research.google.com/drive/16QritJF5kgT3mtnODepld1fo5tFnFCoc .. _Multiprocessing: https://colab.research.google.com/drive/1ZzNFMUUi923foaVsYb4YjPy4mjKtnOxb .. _Monitor Training and Plotting: https://colab.research.google.com/drive/1L_IMo6v0a0ALK8nefZm6PqPSy0vZIWBT .. _Atari Games: https://colab.research.google.com/drive/1iYK11yDzOOqnrXi1Sfjm1iekZr4cxLaN .. _Breakout: https://colab.research.google.com/drive/14NwwEHwN4hdNgGzzySjxQhEVDff-zr7O +.. _Hindsight Experience Replay: https://colab.research.google.com/drive/1VDD0uLi8wjUXIqAdLKiK15XaEe0z2FOc .. |colab| image:: ../_static/img/colab.svg Basic Usage: Training, Saving, Loading -------------------------------------- -In the following example, we will train, save and load an A2C model on the Lunar Lander environment. +In the following example, we will train, save and load a DQN model on the Lunar Lander environment. .. image:: ../_static/img/try_it.png :scale: 30 % - :target: https://colab.research.google.com/drive/1KoAQ1C_BNtGV3sVvZCnNZaER9rstmy0s + :target: https://colab.research.google.com/drive/16QritJF5kgT3mtnODepld1fo5tFnFCoc -.. figure:: https://cdn-images-1.medium.com/max/960/1*W7X69nxINgZEcJEAyoHCVw.gif +.. figure:: https://cdn-images-1.medium.com/max/960/1*f4VZPKOI0PYNWiwt0la0Rg.gif Lunar Lander Environment @@ -46,25 +48,21 @@ In the following example, we will train, save and load an A2C model on the Lunar import gym - from stable_baselines.common.policies import MlpPolicy - from stable_baselines.common.vec_env import DummyVecEnv - from stable_baselines import A2C + from stable_baselines import DQN - # Create and wrap the environment + # Create environment env = gym.make('LunarLander-v2') - env = DummyVecEnv([lambda: env]) - # Alternatively, you can directly use: - # model = A2C('MlpPolicy', 'LunarLander-v2', ent_coef=0.1, verbose=1) - model = A2C(MlpPolicy, env, ent_coef=0.1, verbose=1) + # Instantiate the agent + model = DQN('MlpPolicy', env, learning_rate=1e-3, prioritized_replay=True, verbose=1) # Train the agent - model.learn(total_timesteps=100000) + model.learn(total_timesteps=int(2e5)) # Save the agent - model.save("a2c_lunar") + model.save("dqn_lunar") del model # delete trained model to demonstrate loading # Load the trained agent - model = A2C.load("a2c_lunar") + model = DQN.load("dqn_lunar") # Enjoy trained agent obs = env.reset() @@ -152,12 +150,11 @@ If your callback returns False, training is aborted early. import numpy as np import matplotlib.pyplot as plt - from stable_baselines.ddpg.policies import MlpPolicy - from stable_baselines.common.vec_env.dummy_vec_env import DummyVecEnv + from stable_baselines.ddpg.policies import LnMlpPolicy from stable_baselines.bench import Monitor from stable_baselines.results_plotter import load_results, ts2xy from stable_baselines import DDPG - from stable_baselines.ddpg.noise import AdaptiveParamNoiseSpec + from stable_baselines.ddpg import AdaptiveParamNoiseSpec best_mean_reward, n_steps = -np.inf, 0 @@ -171,7 +168,7 @@ If your callback returns False, training is aborted early. global n_steps, best_mean_reward # Print stats every 1000 calls if (n_steps + 1) % 1000 == 0: - # Evaluate policy performance + # Evaluate policy training performance x, y = ts2xy(load_results(log_dir), 'timesteps') if len(x) > 0: mean_reward = np.mean(y[-100:]) @@ -195,13 +192,14 @@ If your callback returns False, training is aborted early. # Create and wrap the environment env = gym.make('LunarLanderContinuous-v2') env = Monitor(env, log_dir, allow_early_resets=True) - env = DummyVecEnv([lambda: env]) # Add some param noise for exploration - param_noise = AdaptiveParamNoiseSpec(initial_stddev=0.2, desired_action_stddev=0.2) - model = DDPG(MlpPolicy, env, param_noise=param_noise, memory_limit=int(1e6), verbose=0) + param_noise = AdaptiveParamNoiseSpec(initial_stddev=0.1, desired_action_stddev=0.1) + # Because we use parameter noise, we should use a MlpPolicy with layer normalization + model = DDPG(LnMlpPolicy, env, param_noise=param_noise, verbose=0) # Train the agent - model.learn(total_timesteps=200000, callback=callback) + model.learn(total_timesteps=int(1e5), callback=callback) + Atari Games ----------- @@ -354,6 +352,12 @@ Hindsight Experience Replay (HER) For this example, we are using `Highway-Env `_ by `@eleurent `_. + +.. image:: ../_static/img/try_it.png + :scale: 30 % + :target: https://colab.research.google.com/drive/1VDD0uLi8wjUXIqAdLKiK15XaEe0z2FOc + + .. figure:: https://raw.githubusercontent.com/eleurent/highway-env/gh-media/docs/media/parking-env.gif The highway-parking-v0 environment. From 0be6f84db02e5e6f4811e4656542358dc16c58b1 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 19 May 2019 16:32:02 +0200 Subject: [PATCH 19/36] Add random exploration for SAC and DDPG --- docs/misc/changelog.rst | 3 +++ stable_baselines/ddpg/ddpg.py | 17 +++++++++++++++-- stable_baselines/sac/sac.py | 13 ++++++++++--- tests/test_her.py | 4 +++- 4 files changed, 31 insertions(+), 6 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 1bfcb30228..d3382d0efa 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -20,6 +20,9 @@ Pre-Release 2.6.0a0 (WIP) - Removed ``get_available_gpus`` function which hadn't been used anywhere (@Pastafarianist) - Fixed path splitting in ``TensorboardWriter._get_latest_run_id()`` on Windows machines (@PatrickWalter214) - The parameter ``filter_size`` of the function ``conv`` in A2C utils now supports passing a list/tuple of two integers (height and width), in order to have non-squared kernel matrix. (@yutingsz) +- add ``random_exploration`` parameter for DDPG and SAC, it may be useful when using HER + DDPG/SAC + this hack was present in the original OpenAI Baselines DDPG + HER implementation. + Release 2.5.1 (2019-05-04) -------------------------- diff --git a/stable_baselines/ddpg/ddpg.py b/stable_baselines/ddpg/ddpg.py index 97ee031130..2aa04bf778 100644 --- a/stable_baselines/ddpg/ddpg.py +++ b/stable_baselines/ddpg/ddpg.py @@ -174,6 +174,9 @@ class DDPG(OffPolicyRLModel): Use `buffer_size` instead. :param buffer_size: (int) the max number of transitions to store, size of the replay buffer + :param random_exploration: (float) Probability of taken a random action (as in an epsilon-greedy strategy) + This is not needed for DDPG normally but can help exploring when using HER + DDPG. + This hack was present in the original OpenAI Baselines repo (DDPG + HER) :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance @@ -187,7 +190,7 @@ def __init__(self, policy, env, gamma=0.99, memory_policy=None, eval_env=None, n normalize_observations=False, tau=0.001, batch_size=128, param_noise_adaption_interval=50, normalize_returns=False, enable_popart=False, observation_range=(-5., 5.), critic_l2_reg=0., return_range=(-np.inf, np.inf), actor_lr=1e-4, critic_lr=1e-3, clip_norm=None, reward_scale=1., - render=False, render_eval=False, memory_limit=None, buffer_size=50000, + render=False, render_eval=False, memory_limit=None, buffer_size=50000, random_exploration=0.0, verbose=0, tensorboard_log=None, _init_setup_model=True, policy_kwargs=None, full_tensorboard_log=False): @@ -233,6 +236,7 @@ def __init__(self, policy, env, gamma=0.99, memory_policy=None, eval_env=None, n self.buffer_size = buffer_size self.tensorboard_log = tensorboard_log self.full_tensorboard_log = full_tensorboard_log + self.random_exploration = random_exploration # init self.graph = None @@ -853,7 +857,15 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ # Execute next action. if rank == 0 and self.render: self.env.render() - new_obs, reward, done, info = self.env.step(action * np.abs(self.action_space.low)) + + # Randomly sample actions from a uniform distribution + # with a probabilty self.random_exploration (used in HER + DDPG) + if np.random.rand() < self.random_exploration: + rescaled_action = action = self.action_space.sample() + else: + rescaled_action = action * np.abs(self.action_space.low) + + new_obs, reward, done, info = self.env.step(rescaled_action) if writer is not None: ep_rew = np.array([reward]).reshape((1, -1)) @@ -1057,6 +1069,7 @@ def save(self, save_path): "reward_scale": self.reward_scale, "memory_limit": self.memory_limit, "buffer_size": self.buffer_size, + "random_exploration": self.random_exploration, "policy": self.policy, "n_envs": self.n_envs, "_vectorize_action": self._vectorize_action, diff --git a/stable_baselines/sac/sac.py b/stable_baselines/sac/sac.py index 77b6888a05..485ac1d143 100644 --- a/stable_baselines/sac/sac.py +++ b/stable_baselines/sac/sac.py @@ -55,6 +55,9 @@ class SAC(OffPolicyRLModel): :param target_entropy: (str or float) target entropy when learning ent_coef (ent_coef = 'auto') :param action_noise: (ActionNoise) the action noise type (None by default), this can help for hard exploration problem. Cf DDPG for the different action noise type. + :param random_exploration: (float) Probability of taken a random action (as in an epsilon-greedy strategy) + This is not needed for SAC normally but can help exploring when using HER + SAC. + This hack was present in the original OpenAI Baselines repo (DDPG + HER) :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance @@ -67,7 +70,7 @@ def __init__(self, policy, env, gamma=0.99, learning_rate=3e-4, buffer_size=5000 learning_starts=100, train_freq=1, batch_size=64, tau=0.005, ent_coef='auto', target_update_interval=1, gradient_steps=1, target_entropy='auto', action_noise=None, - verbose=0, tensorboard_log=None, + random_exploration=0.0, verbose=0, tensorboard_log=None, _init_setup_model=True, policy_kwargs=None, full_tensorboard_log=False): super(SAC, self).__init__(policy=policy, env=env, replay_buffer=None, verbose=verbose, @@ -90,6 +93,7 @@ def __init__(self, policy, env, gamma=0.99, learning_rate=3e-4, buffer_size=5000 self.gradient_steps = gradient_steps self.gamma = gamma self.action_noise = action_noise + self.random_exploration = random_exploration self.value_fn = None self.graph = None @@ -393,8 +397,10 @@ def learn(self, total_timesteps, callback=None, seed=None, # Before training starts, randomly sample actions # from a uniform distribution for better exploration. - # Afterwards, use the learned policy. - if self.num_timesteps < self.learning_starts: + # Afterwards, use the learned policy + # if random_exploration is set to 0 (normal setting) + if (self.num_timesteps < self.learning_starts + or np.random.rand() < self.random_exploration): action = self.env.action_space.sample() # No need to rescale when sampling random action rescaled_action = action @@ -538,6 +544,7 @@ def save(self, save_path): "policy": self.policy, "n_envs": self.n_envs, "action_noise": self.action_noise, + "random_exploration": self.random_exploration, "_vectorize_action": self._vectorize_action, "policy_kwargs": self.policy_kwargs } diff --git a/tests/test_her.py b/tests/test_her.py index f145e7223a..8d239be52e 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -36,8 +36,10 @@ def model_predict(model, env, n_steps, additional_check=None): def test_her(model_class, goal_selection_strategy): env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC], max_steps=N_BITS) + # Take random actions 10% of the time + kwargs = {'random_exploration': 0.1} if model_class in [DDPG, SAC] else {} model = HER('MlpPolicy', env, model_class, n_sampled_goal=4, goal_selection_strategy=goal_selection_strategy, - verbose=0) + verbose=0, **kwargs) model.learn(1000) From b208889586a95d09bbfbdb9adbcc7b2b205928ba Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 19 May 2019 16:35:05 +0200 Subject: [PATCH 20/36] Typo in docstring --- stable_baselines/ddpg/ddpg.py | 2 +- stable_baselines/sac/sac.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/stable_baselines/ddpg/ddpg.py b/stable_baselines/ddpg/ddpg.py index 2aa04bf778..392554f8d3 100644 --- a/stable_baselines/ddpg/ddpg.py +++ b/stable_baselines/ddpg/ddpg.py @@ -174,7 +174,7 @@ class DDPG(OffPolicyRLModel): Use `buffer_size` instead. :param buffer_size: (int) the max number of transitions to store, size of the replay buffer - :param random_exploration: (float) Probability of taken a random action (as in an epsilon-greedy strategy) + :param random_exploration: (float) Probability of taking a random action (as in an epsilon-greedy strategy) This is not needed for DDPG normally but can help exploring when using HER + DDPG. This hack was present in the original OpenAI Baselines repo (DDPG + HER) :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug diff --git a/stable_baselines/sac/sac.py b/stable_baselines/sac/sac.py index 485ac1d143..78f4f451c7 100644 --- a/stable_baselines/sac/sac.py +++ b/stable_baselines/sac/sac.py @@ -55,7 +55,7 @@ class SAC(OffPolicyRLModel): :param target_entropy: (str or float) target entropy when learning ent_coef (ent_coef = 'auto') :param action_noise: (ActionNoise) the action noise type (None by default), this can help for hard exploration problem. Cf DDPG for the different action noise type. - :param random_exploration: (float) Probability of taken a random action (as in an epsilon-greedy strategy) + :param random_exploration: (float) Probability of taking a random action (as in an epsilon-greedy strategy) This is not needed for SAC normally but can help exploring when using HER + SAC. This hack was present in the original OpenAI Baselines repo (DDPG + HER) :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug @@ -401,9 +401,8 @@ def learn(self, total_timesteps, callback=None, seed=None, # if random_exploration is set to 0 (normal setting) if (self.num_timesteps < self.learning_starts or np.random.rand() < self.random_exploration): - action = self.env.action_space.sample() # No need to rescale when sampling random action - rescaled_action = action + rescaled_action = action = self.env.action_space.sample() else: action = self.policy_tf.step(obs[None], deterministic=False).flatten() # Add noise to the action (improve exploration, From 27699bf4e1aafd11f3e15f935dc92c65fb32c1cc Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 19 May 2019 16:44:18 +0200 Subject: [PATCH 21/36] Doc update: add fix for DDPG saved models --- docs/misc/changelog.rst | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index d3382d0efa..b3eb629d39 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -9,11 +9,11 @@ For download links, please look at `Github release page = "2.6.0": + sys.modules['stable_baselines.ddpg.memory'] = stable_baselines.deepq.replay_buffer + stable_baselines.deepq.replay_buffer.Memory = stable_baselines.deepq.replay_buffer.ReplayBuffer + + +We recommend you to save again the model afterward, so the fix won't be needed the next time the trained agent is loaded. + + + Release 2.5.1 (2019-05-04) -------------------------- From 87db166af6eaf23aba30b920115b3cf28477eed3 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 22 May 2019 19:30:15 +0200 Subject: [PATCH 22/36] Test with reward offset --- stable_baselines/her/replay_buffer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/stable_baselines/her/replay_buffer.py b/stable_baselines/her/replay_buffer.py index 1b9121ba0d..ce7273bc12 100644 --- a/stable_baselines/her/replay_buffer.py +++ b/stable_baselines/her/replay_buffer.py @@ -137,8 +137,9 @@ def _store_episode(self): for transition_idx, transition in enumerate(self.episode_transitions): obs_t, action, reward, obs_tp1, done = transition + reward_offset = 0.0 # Add to the replay buffer - self.replay_buffer.add(obs_t, action, reward, obs_tp1, done) + self.replay_buffer.add(obs_t, action, reward + reward_offset, obs_tp1, done) # We cannot sample a goal from the future in the last step of an episode if (transition_idx == len(self.episode_transitions) - 1 and @@ -169,4 +170,4 @@ def _store_episode(self): obs, next_obs = map(self.env.convert_dict_to_obs, (obs_dict, next_obs_dict)) # Add artificial transition to the replay buffer - self.replay_buffer.add(obs, action, reward, next_obs, done) + self.replay_buffer.add(obs, action, reward + reward_offset, next_obs, done) From 1a7e0906d2d57490ad050b4f9e6413bafea9811e Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 22 May 2019 19:34:10 +0200 Subject: [PATCH 23/36] Add GoalEnvNormalize draft --- stable_baselines/common/vec_env/__init__.py | 2 +- .../common/vec_env/vec_normalize.py | 103 ++++++++++++++++++ stable_baselines/her/her.py | 8 +- 3 files changed, 109 insertions(+), 4 deletions(-) diff --git a/stable_baselines/common/vec_env/__init__.py b/stable_baselines/common/vec_env/__init__.py index b8597551d3..17ae72a9ac 100644 --- a/stable_baselines/common/vec_env/__init__.py +++ b/stable_baselines/common/vec_env/__init__.py @@ -4,5 +4,5 @@ from stable_baselines.common.vec_env.dummy_vec_env import DummyVecEnv from stable_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv from stable_baselines.common.vec_env.vec_frame_stack import VecFrameStack -from stable_baselines.common.vec_env.vec_normalize import VecNormalize +from stable_baselines.common.vec_env.vec_normalize import VecNormalize, GoalEnvVecNormalize from stable_baselines.common.vec_env.vec_video_recorder import VecVideoRecorder diff --git a/stable_baselines/common/vec_env/vec_normalize.py b/stable_baselines/common/vec_env/vec_normalize.py index 8275dbf5de..c77a7255a8 100644 --- a/stable_baselines/common/vec_env/vec_normalize.py +++ b/stable_baselines/common/vec_env/vec_normalize.py @@ -103,3 +103,106 @@ def load_running_average(self, path): for name in ['obs_rms', 'ret_rms']: with open("{}/{}.pkl".format(path, name), 'rb') as file_handler: setattr(self, name, pickle.load(file_handler)) + + +class GoalEnvVecNormalize(VecEnvWrapper): + """ + A moving average, normalizing wrapper for vectorized environment. + has support for saving/loading moving average, + + :param venv: (VecEnv) the vectorized environment to wrap + :param training: (bool) Whether to update or not the moving average + :param norm_obs: (bool) Whether to normalize observation or not (default: True) + :param norm_reward: (bool) Whether to normalize rewards or not (default: True) + :param clip_obs: (float) Max absolute value for observation + :param clip_reward: (float) Max value absolute for discounted reward + :param gamma: (float) discount factor + :param epsilon: (float) To avoid division by zero + """ + + def __init__(self, venv, training=True, norm_obs=True, norm_reward=True, + clip_obs=10., clip_reward=10., gamma=0.99, epsilon=1e-8): + VecEnvWrapper.__init__(self, venv) + self.obs_rms = RunningMeanStd(shape=self.observation_space['observation'].shape) + self.goal_rms = RunningMeanStd(shape=self.observation_space['achieved_goal'].shape) + self.ret_rms = RunningMeanStd(shape=()) + self.clip_obs = clip_obs + self.clip_reward = clip_reward + # Returns: discounted rewards + self.ret = np.zeros(self.num_envs) + self.gamma = gamma + self.epsilon = epsilon + self.training = training + self.norm_obs = norm_obs + self.norm_reward = norm_reward + self.old_obs = np.array([]) + + def step_wait(self): + """ + Apply sequence of actions to sequence of environments + actions -> (observations, rewards, news) + + where 'news' is a boolean vector indicating whether each element is new. + """ + obs, rews, news, infos = self.venv.step_wait() + self.ret = self.ret * self.gamma + rews + self.old_obs = obs + obs = self._normalize_observation(obs) + if self.norm_reward: + if self.training: + self.ret_rms.update(self.ret) + rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward) + self.ret[news] = 0 + return obs, rews, news, infos + + def _normalize_observation(self, obs): + """ + :param obs: (numpy tensor) + """ + if self.norm_obs: + if self.training: + self.obs_rms.update(obs['observation']) + self.goal_rms.update(obs['achieved_goal']) + for key, rms in zip(['observation', 'achieved_goal', 'desired_goal'], + [self.obs_rms, self.goal_rms, self.goal_rms]): + obs[key] = np.clip((obs[key] - rms.mean) / np.sqrt(rms.var + self.epsilon), -self.clip_obs, + self.clip_obs) + return obs + else: + return obs + + def get_original_obs(self): + """ + returns the unnormalized observation + + :return: (numpy float) + """ + return self.old_obs + + def reset(self): + """ + Reset all environments + """ + obs = self.venv.reset() + if len(np.array(obs).shape) == 1: # for when num_cpu is 1 + self.old_obs = [obs] + else: + self.old_obs = obs + self.ret = np.zeros(self.num_envs) + return self._normalize_observation(obs) + + def save_running_average(self, path): + """ + :param path: (str) path to log dir + """ + for rms, name in zip([self.obs_rms, self.goal_rms, self.ret_rms], ['obs_rms', 'goal_rms', 'ret_rms']): + with open("{}/{}.pkl".format(path, name), 'wb') as file_handler: + pickle.dump(rms, file_handler) + + def load_running_average(self, path): + """ + :param path: (str) path to log dir + """ + for name in ['obs_rms', 'goal_rms', 'ret_rms']: + with open("{}/{}.pkl".format(path, name), 'rb') as file_handler: + setattr(self, name, pickle.load(file_handler)) diff --git a/stable_baselines/her/her.py b/stable_baselines/her/her.py index 23a775e2ac..8eb383df99 100644 --- a/stable_baselines/her/her.py +++ b/stable_baselines/her/her.py @@ -3,7 +3,7 @@ from stable_baselines.common import BaseRLModel from stable_baselines.common import OffPolicyRLModel from stable_baselines.common.base_class import _UnvecWrapper -from stable_baselines.common.vec_env import VecEnvWrapper +from stable_baselines.common.vec_env import VecEnvWrapper, GoalEnvVecNormalize from .replay_buffer import HindsightExperienceReplayWrapper, KEY_TO_GOAL_STRATEGY from .utils import HERGoalEnvWrapper @@ -23,7 +23,8 @@ class HER(BaseRLModel): def __init__(self, policy, env, model_class, n_sampled_goal=4, goal_selection_strategy='future', *args, **kwargs): - assert not isinstance(env, VecEnvWrapper), "HER does not support VecEnvWrapper yet" + if isinstance(env, VecEnvWrapper): + assert isinstance(env, GoalEnvVecNormalize), "HER support only GoalEnvVecNormalize" super().__init__(policy=policy, env=env, verbose=kwargs.get('verbose', 0), policy_base=None, requires_vec_env=False) @@ -72,7 +73,8 @@ def _create_replay_wrapper(self, env): wrapped_env=self.env) def set_env(self, env): - assert not isinstance(env, VecEnvWrapper), "HER does not support VecEnvWrapper yet" + if isinstance(env, VecEnvWrapper): + assert isinstance(env, GoalEnvVecNormalize), "HER support only GoalEnvVecNormalize" super().set_env(env) self._create_replay_wrapper(self.env) self.model.set_env(self.env) From 7592bbd9bfe6a826b68eaf88e56d2008ced81fd9 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 23 May 2019 23:45:44 +0200 Subject: [PATCH 24/36] Remove GoalEnvNormalize --- stable_baselines/common/vec_env/__init__.py | 2 +- .../common/vec_env/vec_normalize.py | 103 ------------------ stable_baselines/her/her.py | 8 +- stable_baselines/her/replay_buffer.py | 6 +- 4 files changed, 7 insertions(+), 112 deletions(-) diff --git a/stable_baselines/common/vec_env/__init__.py b/stable_baselines/common/vec_env/__init__.py index 17ae72a9ac..b8597551d3 100644 --- a/stable_baselines/common/vec_env/__init__.py +++ b/stable_baselines/common/vec_env/__init__.py @@ -4,5 +4,5 @@ from stable_baselines.common.vec_env.dummy_vec_env import DummyVecEnv from stable_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv from stable_baselines.common.vec_env.vec_frame_stack import VecFrameStack -from stable_baselines.common.vec_env.vec_normalize import VecNormalize, GoalEnvVecNormalize +from stable_baselines.common.vec_env.vec_normalize import VecNormalize from stable_baselines.common.vec_env.vec_video_recorder import VecVideoRecorder diff --git a/stable_baselines/common/vec_env/vec_normalize.py b/stable_baselines/common/vec_env/vec_normalize.py index c77a7255a8..8275dbf5de 100644 --- a/stable_baselines/common/vec_env/vec_normalize.py +++ b/stable_baselines/common/vec_env/vec_normalize.py @@ -103,106 +103,3 @@ def load_running_average(self, path): for name in ['obs_rms', 'ret_rms']: with open("{}/{}.pkl".format(path, name), 'rb') as file_handler: setattr(self, name, pickle.load(file_handler)) - - -class GoalEnvVecNormalize(VecEnvWrapper): - """ - A moving average, normalizing wrapper for vectorized environment. - has support for saving/loading moving average, - - :param venv: (VecEnv) the vectorized environment to wrap - :param training: (bool) Whether to update or not the moving average - :param norm_obs: (bool) Whether to normalize observation or not (default: True) - :param norm_reward: (bool) Whether to normalize rewards or not (default: True) - :param clip_obs: (float) Max absolute value for observation - :param clip_reward: (float) Max value absolute for discounted reward - :param gamma: (float) discount factor - :param epsilon: (float) To avoid division by zero - """ - - def __init__(self, venv, training=True, norm_obs=True, norm_reward=True, - clip_obs=10., clip_reward=10., gamma=0.99, epsilon=1e-8): - VecEnvWrapper.__init__(self, venv) - self.obs_rms = RunningMeanStd(shape=self.observation_space['observation'].shape) - self.goal_rms = RunningMeanStd(shape=self.observation_space['achieved_goal'].shape) - self.ret_rms = RunningMeanStd(shape=()) - self.clip_obs = clip_obs - self.clip_reward = clip_reward - # Returns: discounted rewards - self.ret = np.zeros(self.num_envs) - self.gamma = gamma - self.epsilon = epsilon - self.training = training - self.norm_obs = norm_obs - self.norm_reward = norm_reward - self.old_obs = np.array([]) - - def step_wait(self): - """ - Apply sequence of actions to sequence of environments - actions -> (observations, rewards, news) - - where 'news' is a boolean vector indicating whether each element is new. - """ - obs, rews, news, infos = self.venv.step_wait() - self.ret = self.ret * self.gamma + rews - self.old_obs = obs - obs = self._normalize_observation(obs) - if self.norm_reward: - if self.training: - self.ret_rms.update(self.ret) - rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward) - self.ret[news] = 0 - return obs, rews, news, infos - - def _normalize_observation(self, obs): - """ - :param obs: (numpy tensor) - """ - if self.norm_obs: - if self.training: - self.obs_rms.update(obs['observation']) - self.goal_rms.update(obs['achieved_goal']) - for key, rms in zip(['observation', 'achieved_goal', 'desired_goal'], - [self.obs_rms, self.goal_rms, self.goal_rms]): - obs[key] = np.clip((obs[key] - rms.mean) / np.sqrt(rms.var + self.epsilon), -self.clip_obs, - self.clip_obs) - return obs - else: - return obs - - def get_original_obs(self): - """ - returns the unnormalized observation - - :return: (numpy float) - """ - return self.old_obs - - def reset(self): - """ - Reset all environments - """ - obs = self.venv.reset() - if len(np.array(obs).shape) == 1: # for when num_cpu is 1 - self.old_obs = [obs] - else: - self.old_obs = obs - self.ret = np.zeros(self.num_envs) - return self._normalize_observation(obs) - - def save_running_average(self, path): - """ - :param path: (str) path to log dir - """ - for rms, name in zip([self.obs_rms, self.goal_rms, self.ret_rms], ['obs_rms', 'goal_rms', 'ret_rms']): - with open("{}/{}.pkl".format(path, name), 'wb') as file_handler: - pickle.dump(rms, file_handler) - - def load_running_average(self, path): - """ - :param path: (str) path to log dir - """ - for name in ['obs_rms', 'goal_rms', 'ret_rms']: - with open("{}/{}.pkl".format(path, name), 'rb') as file_handler: - setattr(self, name, pickle.load(file_handler)) diff --git a/stable_baselines/her/her.py b/stable_baselines/her/her.py index 8eb383df99..b125cc032e 100644 --- a/stable_baselines/her/her.py +++ b/stable_baselines/her/her.py @@ -3,7 +3,7 @@ from stable_baselines.common import BaseRLModel from stable_baselines.common import OffPolicyRLModel from stable_baselines.common.base_class import _UnvecWrapper -from stable_baselines.common.vec_env import VecEnvWrapper, GoalEnvVecNormalize +from stable_baselines.common.vec_env import VecEnvWrapper from .replay_buffer import HindsightExperienceReplayWrapper, KEY_TO_GOAL_STRATEGY from .utils import HERGoalEnvWrapper @@ -23,8 +23,7 @@ class HER(BaseRLModel): def __init__(self, policy, env, model_class, n_sampled_goal=4, goal_selection_strategy='future', *args, **kwargs): - if isinstance(env, VecEnvWrapper): - assert isinstance(env, GoalEnvVecNormalize), "HER support only GoalEnvVecNormalize" + assert not isinstance(env, VecEnvWrapper), "HER does not support VecEnvWrapper" super().__init__(policy=policy, env=env, verbose=kwargs.get('verbose', 0), policy_base=None, requires_vec_env=False) @@ -73,8 +72,7 @@ def _create_replay_wrapper(self, env): wrapped_env=self.env) def set_env(self, env): - if isinstance(env, VecEnvWrapper): - assert isinstance(env, GoalEnvVecNormalize), "HER support only GoalEnvVecNormalize" + assert not isinstance(env, VecEnvWrapper), "HER does not support VecEnvWrapper" super().set_env(env) self._create_replay_wrapper(self.env) self.model.set_env(self.env) diff --git a/stable_baselines/her/replay_buffer.py b/stable_baselines/her/replay_buffer.py index ce7273bc12..82f61a7f49 100644 --- a/stable_baselines/her/replay_buffer.py +++ b/stable_baselines/her/replay_buffer.py @@ -137,9 +137,9 @@ def _store_episode(self): for transition_idx, transition in enumerate(self.episode_transitions): obs_t, action, reward, obs_tp1, done = transition - reward_offset = 0.0 + # Add to the replay buffer - self.replay_buffer.add(obs_t, action, reward + reward_offset, obs_tp1, done) + self.replay_buffer.add(obs_t, action, reward, obs_tp1, done) # We cannot sample a goal from the future in the last step of an episode if (transition_idx == len(self.episode_transitions) - 1 and @@ -170,4 +170,4 @@ def _store_episode(self): obs, next_obs = map(self.env.convert_dict_to_obs, (obs_dict, next_obs_dict)) # Add artificial transition to the replay buffer - self.replay_buffer.add(obs, action, reward + reward_offset, next_obs, done) + self.replay_buffer.add(obs, action, reward, next_obs, done) From 730b1719dbd34faa34a57fad94d05e55ca2eb9bd Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 1 Jun 2019 00:15:37 +0200 Subject: [PATCH 25/36] Fix typo --- stable_baselines/a2c/a2c.py | 2 +- stable_baselines/ppo2/ppo2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/stable_baselines/a2c/a2c.py b/stable_baselines/a2c/a2c.py index e8618f0712..27d2bc1656 100644 --- a/stable_baselines/a2c/a2c.py +++ b/stable_baselines/a2c/a2c.py @@ -23,7 +23,7 @@ class A2C(ActorCriticRLModel): :param n_steps: (int) The number of steps to run for each environment per update (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel) :param vf_coef: (float) Value function coefficient for the loss calculation - :param ent_coef: (float) Entropy coefficient for the loss caculation + :param ent_coef: (float) Entropy coefficient for the loss calculation :param max_grad_norm: (float) The maximum value for the gradient clipping :param learning_rate: (float) The learning rate :param alpha: (float) RMSProp decay parameter (default: 0.99) diff --git a/stable_baselines/ppo2/ppo2.py b/stable_baselines/ppo2/ppo2.py index eb009cee78..12c29324de 100644 --- a/stable_baselines/ppo2/ppo2.py +++ b/stable_baselines/ppo2/ppo2.py @@ -24,7 +24,7 @@ class PPO2(ActorCriticRLModel): :param gamma: (float) Discount factor :param n_steps: (int) The number of steps to run for each environment per update (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel) - :param ent_coef: (float) Entropy coefficient for the loss caculation + :param ent_coef: (float) Entropy coefficient for the loss calculation :param learning_rate: (float or callable) The learning rate, it can be a function :param vf_coef: (float) Value function coefficient for the loss calculation :param max_grad_norm: (float) The maximum value for the gradient clipping From 635c7d059079d8739d1de1e4573a95fcc4948c8f Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 1 Jun 2019 13:36:09 +0200 Subject: [PATCH 26/36] Bug fix for HER + VecEnv --- stable_baselines/her/utils.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/stable_baselines/her/utils.py b/stable_baselines/her/utils.py index 00524e7abd..5c391d26f9 100644 --- a/stable_baselines/her/utils.py +++ b/stable_baselines/her/utils.py @@ -1,3 +1,5 @@ +from collections import OrderedDict + import numpy as np from gym import spaces @@ -21,6 +23,13 @@ def __init__(self, env): # (current limitation of the wrapper) # TODO: check when dim > 1 + # Sanity check because we are doing dict to array operations + subspaces_keys = list(env.observation_space.spaces.keys()) + desired_keys = ['achieved_goal', 'desired_goal', 'observation'] + assert subspaces_keys == desired_keys,\ + "The keys of the GoalEnv must be ordered"\ + "in the following order:{} != {}".format(desired_keys, subspaces_keys) + goal_space_shape = env.observation_space.spaces['achieved_goal'].shape self.obs_dim = env.observation_space.spaces['observation'].shape[0] self.goal_dim = goal_space_shape[0] @@ -60,13 +69,13 @@ def convert_obs_to_dict(self, observations): Inverse operation of convert_dict_to_obs :param observations: (np.ndarray) - :return: (dict) + :return: (OrderedDict) """ - return { - 'observation': observations[:self.obs_dim], - 'achieved_goal': observations[self.obs_dim:self.obs_dim + self.goal_dim], - 'desired_goal': observations[self.obs_dim + self.goal_dim:], - } + return OrderedDict([ + ('achieved_goal', observations[:self.goal_dim]), + ('desired_goal', observations[self.goal_dim:2 * self.goal_dim]), + ('observation', observations[-self.obs_dim:]), + ]) def step(self, action): obs, reward, done, info = self.env.step(action) From bf363adffd64c132f5921638901cc1a004780f20 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 1 Jun 2019 13:47:22 +0200 Subject: [PATCH 27/36] Fix HER test env --- stable_baselines/common/bit_flipping_env.py | 26 +++++++++++---------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/stable_baselines/common/bit_flipping_env.py b/stable_baselines/common/bit_flipping_env.py index 768b7d0820..5574a4df02 100644 --- a/stable_baselines/common/bit_flipping_env.py +++ b/stable_baselines/common/bit_flipping_env.py @@ -1,3 +1,5 @@ +from collections import OrderedDict + import numpy as np from gym import GoalEnv, spaces @@ -18,9 +20,9 @@ def __init__(self, n_bits=10, continuous=False, max_steps=None): # The achieved goal is determined by the current state # here, it is a special where they are equal self.observation_space = spaces.Dict({ - 'observation': spaces.MultiBinary(n_bits), 'achieved_goal': spaces.MultiBinary(n_bits), - 'desired_goal': spaces.MultiBinary(n_bits) + 'desired_goal': spaces.MultiBinary(n_bits), + 'observation': spaces.MultiBinary(n_bits) }) if continuous: self.action_space = spaces.Box(-1, 1, shape=(n_bits,), dtype=np.float32) @@ -38,22 +40,22 @@ def __init__(self, n_bits=10, continuous=False, max_steps=None): def reset(self): self.current_step = 0 self.state = self.observation_space.spaces['observation'].sample() - return { - 'observation': self.state.copy(), - 'achieved_goal': self.state.copy(), - 'desired_goal': self.desired_goal.copy() - } + return OrderedDict([ + ('achieved_goal', self.state.copy()), + ('desired_goal', self.desired_goal.copy()), + ('observation', self.state.copy()) + ]) def step(self, action): if self.continuous: self.state[action > 0] = 1 - self.state[action > 0] else: self.state[action] = 1 - self.state[action] - obs = { - 'observation': self.state.copy(), - 'achieved_goal': self.state.copy(), - 'desired_goal': self.desired_goal.copy() - } + obs = OrderedDict([ + ('achieved_goal', self.state.copy()), + ('desired_goal', self.desired_goal.copy()), + ('observation', self.state.copy()) + ]) reward = self.compute_reward(obs['achieved_goal'], obs['desired_goal'], None) done = (obs['achieved_goal'] == obs['desired_goal']).all() self.current_step += 1 From ccbc5c74798c9d7e92a7550d745c8a62e54bfc94 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 1 Jun 2019 16:45:29 +0200 Subject: [PATCH 28/36] Fixed key order --- stable_baselines/common/bit_flipping_env.py | 12 ++++++------ stable_baselines/her/utils.py | 19 ++++++++----------- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/stable_baselines/common/bit_flipping_env.py b/stable_baselines/common/bit_flipping_env.py index 5574a4df02..e4c412c163 100644 --- a/stable_baselines/common/bit_flipping_env.py +++ b/stable_baselines/common/bit_flipping_env.py @@ -20,9 +20,9 @@ def __init__(self, n_bits=10, continuous=False, max_steps=None): # The achieved goal is determined by the current state # here, it is a special where they are equal self.observation_space = spaces.Dict({ + 'observation': spaces.MultiBinary(n_bits), 'achieved_goal': spaces.MultiBinary(n_bits), - 'desired_goal': spaces.MultiBinary(n_bits), - 'observation': spaces.MultiBinary(n_bits) + 'desired_goal': spaces.MultiBinary(n_bits) }) if continuous: self.action_space = spaces.Box(-1, 1, shape=(n_bits,), dtype=np.float32) @@ -41,9 +41,9 @@ def reset(self): self.current_step = 0 self.state = self.observation_space.spaces['observation'].sample() return OrderedDict([ + ('observation', self.state.copy()), ('achieved_goal', self.state.copy()), - ('desired_goal', self.desired_goal.copy()), - ('observation', self.state.copy()) + ('desired_goal', self.desired_goal.copy()) ]) def step(self, action): @@ -52,9 +52,9 @@ def step(self, action): else: self.state[action] = 1 - self.state[action] obs = OrderedDict([ + ('observation', self.state.copy()), ('achieved_goal', self.state.copy()), - ('desired_goal', self.desired_goal.copy()), - ('observation', self.state.copy()) + ('desired_goal', self.desired_goal.copy()) ]) reward = self.compute_reward(obs['achieved_goal'], obs['desired_goal'], None) done = (obs['achieved_goal'] == obs['desired_goal']).all() diff --git a/stable_baselines/her/utils.py b/stable_baselines/her/utils.py index 5c391d26f9..fd6ecdd240 100644 --- a/stable_baselines/her/utils.py +++ b/stable_baselines/her/utils.py @@ -3,6 +3,10 @@ import numpy as np from gym import spaces +# Important: gym mixes up ordered and unordered keys +# and the Dict space may return a different order of keys that the actual one +KEY_ORDER = ['observation', 'achieved_goal', 'desired_goal'] + class HERGoalEnvWrapper(object): """ @@ -23,13 +27,6 @@ def __init__(self, env): # (current limitation of the wrapper) # TODO: check when dim > 1 - # Sanity check because we are doing dict to array operations - subspaces_keys = list(env.observation_space.spaces.keys()) - desired_keys = ['achieved_goal', 'desired_goal', 'observation'] - assert subspaces_keys == desired_keys,\ - "The keys of the GoalEnv must be ordered"\ - "in the following order:{} != {}".format(desired_keys, subspaces_keys) - goal_space_shape = env.observation_space.spaces['achieved_goal'].shape self.obs_dim = env.observation_space.spaces['observation'].shape[0] self.goal_dim = goal_space_shape[0] @@ -62,7 +59,7 @@ def convert_dict_to_obs(obs_dict): """ # Note: achieved goal is not removed from the observation # this is helpful to have a revertible transformation - return np.concatenate([obs for obs in obs_dict.values()]) + return np.concatenate([obs_dict[key] for key in KEY_ORDER]) def convert_obs_to_dict(self, observations): """ @@ -72,9 +69,9 @@ def convert_obs_to_dict(self, observations): :return: (OrderedDict) """ return OrderedDict([ - ('achieved_goal', observations[:self.goal_dim]), - ('desired_goal', observations[self.goal_dim:2 * self.goal_dim]), - ('observation', observations[-self.obs_dim:]), + ('observation', observations[:self.obs_dim]), + ('achieved_goal', observations[self.obs_dim:self.obs_dim + self.goal_dim]), + ('desired_goal', observations[self.obs_dim + self.goal_dim:]), ]) def step(self, action): From e1e344b4d65d7d1a9a88584cf010b4ee141d3517 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 2 Jun 2019 20:51:57 +0200 Subject: [PATCH 29/36] Add support for discrete obs space --- stable_baselines/common/bit_flipping_env.py | 76 +++++++++++++++------ stable_baselines/gail/adversary.py | 2 +- stable_baselines/her/utils.py | 39 +++++++---- tests/test_her.py | 6 +- 4 files changed, 86 insertions(+), 37 deletions(-) diff --git a/stable_baselines/common/bit_flipping_env.py b/stable_baselines/common/bit_flipping_env.py index e4c412c163..4b2617a3a8 100644 --- a/stable_baselines/common/bit_flipping_env.py +++ b/stable_baselines/common/bit_flipping_env.py @@ -12,23 +12,40 @@ class BitFlippingEnv(GoalEnv): then the ith bit will be flipped. :param n_bits: (int) Number of bits to flip - :param continuous: (bool) Wether to use the continuous version or not - :param max_steps: (int) Max number of steps, by defaults, equal to n_bits + :param continuous: (bool) Whether to use the continuous actions version or not, + by default, it uses the discrete one + :param max_steps: (int) Max number of steps, by default, equal to n_bits + :param discrete_obs_space: (bool) Whether to use the discrete observation + version or not, by default, it uses the MultiBinary one """ - def __init__(self, n_bits=10, continuous=False, max_steps=None): + def __init__(self, n_bits=10, continuous=False, max_steps=None, + discrete_obs_space=False): super(BitFlippingEnv, self).__init__() # The achieved goal is determined by the current state # here, it is a special where they are equal - self.observation_space = spaces.Dict({ - 'observation': spaces.MultiBinary(n_bits), - 'achieved_goal': spaces.MultiBinary(n_bits), - 'desired_goal': spaces.MultiBinary(n_bits) - }) + if discrete_obs_space: + # In the discrete case, the agent act on the binary + # representation of the observation + self.observation_space = spaces.Dict({ + 'observation': spaces.Discrete(2 ** n_bits - 1), + 'achieved_goal': spaces.Discrete(2 ** n_bits - 1), + 'desired_goal': spaces.Discrete(2 ** n_bits - 1) + }) + else: + self.observation_space = spaces.Dict({ + 'observation': spaces.MultiBinary(n_bits), + 'achieved_goal': spaces.MultiBinary(n_bits), + 'desired_goal': spaces.MultiBinary(n_bits) + }) + + self.obs_space = spaces.MultiBinary(n_bits) + if continuous: self.action_space = spaces.Box(-1, 1, shape=(n_bits,), dtype=np.float32) else: self.action_space = spaces.Discrete(n_bits) self.continuous = continuous + self.discrete_obs_space = discrete_obs_space self.state = None self.desired_goal = np.ones((n_bits,)) if max_steps is None: @@ -37,27 +54,44 @@ def __init__(self, n_bits=10, continuous=False, max_steps=None): self.current_step = 0 self.reset() - def reset(self): - self.current_step = 0 - self.state = self.observation_space.spaces['observation'].sample() + def convert_if_needed(self, state): + """ + Convert to discrete space if needed. + + :param state: (np.ndarray) + :return: (np.ndarray or int) + """ + if self.discrete_obs_space: + # The internal state is the binary representation of the + # observed one + return int(sum([state[i] * 2**i for i in range(len(state))])) + return state + + def _get_obs(self): + """ + Helper to create the observation. + + :return: (OrderedDict) + """ return OrderedDict([ - ('observation', self.state.copy()), - ('achieved_goal', self.state.copy()), - ('desired_goal', self.desired_goal.copy()) + ('observation', self.convert_if_needed(self.state.copy())), + ('achieved_goal', self.convert_if_needed(self.state.copy())), + ('desired_goal', self.convert_if_needed(self.desired_goal.copy())) ]) + def reset(self): + self.current_step = 0 + self.state = self.obs_space.sample() + return self._get_obs() + def step(self, action): if self.continuous: self.state[action > 0] = 1 - self.state[action > 0] else: self.state[action] = 1 - self.state[action] - obs = OrderedDict([ - ('observation', self.state.copy()), - ('achieved_goal', self.state.copy()), - ('desired_goal', self.desired_goal.copy()) - ]) + obs = self._get_obs() reward = self.compute_reward(obs['achieved_goal'], obs['desired_goal'], None) - done = (obs['achieved_goal'] == obs['desired_goal']).all() + done = reward == 0 self.current_step += 1 # Episode terminate when we reached the goal or the max number of steps info = {'is_success': done} @@ -66,6 +100,8 @@ def step(self, action): def compute_reward(self, achieved_goal, desired_goal, _info): # Deceptive reward: it is positive only when the goal is achieved + if self.discrete_obs_space: + return 0 if achieved_goal == desired_goal else -1 return 0 if (achieved_goal == desired_goal).all() else -1 def render(self, mode='human'): diff --git a/stable_baselines/gail/adversary.py b/stable_baselines/gail/adversary.py index c61c648a1b..ade1d977c1 100644 --- a/stable_baselines/gail/adversary.py +++ b/stable_baselines/gail/adversary.py @@ -43,7 +43,7 @@ def __init__(self, observation_space, action_space, hidden_size, :param hidden_size: ([int]) the hidden dimension for the MLP :param entcoeff: (float) the entropy loss weight :param scope: (str) tensorflow variable scope - :param normalize: (bool) Wether to normalize the reward or not + :param normalize: (bool) Whether to normalize the reward or not """ # TODO: support images properly (using a CNN) self.scope = scope diff --git a/stable_baselines/her/utils.py b/stable_baselines/her/utils.py index fd6ecdd240..f1bf696f0a 100644 --- a/stable_baselines/her/utils.py +++ b/stable_baselines/her/utils.py @@ -23,21 +23,28 @@ def __init__(self, env): self.metadata = self.env.metadata self.action_space = env.action_space self.spaces = list(env.observation_space.spaces.values()) - # TODO: check that all spaces are of the same type + # Check that all spaces are of the same type # (current limitation of the wrapper) - # TODO: check when dim > 1 + space_types = [type(env.observation_space.spaces[key]) for key in KEY_ORDER] + assert len(set(space_types)) == 1, "The spaces for goal and observation"\ + " must be of the same type" - goal_space_shape = env.observation_space.spaces['achieved_goal'].shape - self.obs_dim = env.observation_space.spaces['observation'].shape[0] - self.goal_dim = goal_space_shape[0] - total_dim = self.obs_dim + 2 * self.goal_dim - - if len(goal_space_shape) == 2: - assert goal_space_shape[1] == 1, "Only 1D observation spaces are supported yet" + if isinstance(self.spaces[0], spaces.Discrete): + self.obs_dim = 1 + self.goal_dim = 1 else: - assert len(goal_space_shape) == 1, "Only 1D observation spaces are supported yet" + goal_space_shape = env.observation_space.spaces['achieved_goal'].shape + self.obs_dim = env.observation_space.spaces['observation'].shape[0] + self.goal_dim = goal_space_shape[0] + + if len(goal_space_shape) == 2: + assert goal_space_shape[1] == 1, "Only 1D observation spaces are supported yet" + else: + assert len(goal_space_shape) == 1, "Only 1D observation spaces are supported yet" + if isinstance(self.spaces[0], spaces.MultiBinary): + total_dim = self.obs_dim + 2 * self.goal_dim self.observation_space = spaces.MultiBinary(total_dim) elif isinstance(self.spaces[0], spaces.Box): @@ -46,19 +53,23 @@ def __init__(self, env): self.observation_space = spaces.Box(lows, highs, dtype=np.float32) elif isinstance(self.spaces[0], spaces.Discrete): - pass + dimensions = [env.observation_space.spaces[key].n for key in KEY_ORDER] + self.observation_space = spaces.MultiDiscrete(dimensions) else: - raise NotImplementedError() + raise NotImplementedError("{} space is not supported".format(type(self.spaces[0]))) + - @staticmethod - def convert_dict_to_obs(obs_dict): + def convert_dict_to_obs(self, obs_dict): """ :param obs_dict: (dict) :return: (np.ndarray) """ # Note: achieved goal is not removed from the observation # this is helpful to have a revertible transformation + if isinstance(self.observation_space, spaces.MultiDiscrete): + # Special case for multidiscrete + return np.concatenate([[int(obs_dict[key])] for key in KEY_ORDER]) return np.concatenate([obs_dict[key] for key in KEY_ORDER]) def convert_obs_to_dict(self, observations): diff --git a/tests/test_her.py b/tests/test_her.py index 8d239be52e..58a38e67ac 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -33,8 +33,10 @@ def model_predict(model, env, n_steps, additional_check=None): @pytest.mark.parametrize('goal_selection_strategy', list(GoalSelectionStrategy)) @pytest.mark.parametrize('model_class', [DQN, SAC, DDPG]) -def test_her(model_class, goal_selection_strategy): - env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC], max_steps=N_BITS) +@pytest.mark.parametrize('discrete_obs_space', [False, True]) +def test_her(model_class, goal_selection_strategy, discrete_obs_space): + env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC], + max_steps=N_BITS, discrete_obs_space=discrete_obs_space) # Take random actions 10% of the time kwargs = {'random_exploration': 0.1} if model_class in [DDPG, SAC] else {} From 096f0452b738eb8b7fc0e0106af89a9d2fc1ed11 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 2 Jun 2019 20:54:22 +0200 Subject: [PATCH 30/36] Update doc about reproducing experiments --- docs/modules/her.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/modules/her.rst b/docs/modules/her.rst index 505d2af87f..7b8cb8ed05 100644 --- a/docs/modules/her.rst +++ b/docs/modules/her.rst @@ -12,7 +12,9 @@ HER is a method wrapper that works with Off policy methods (DQN, SAC and DDPG fo .. note:: - HER was re-implemented from scratch in Stable-Baselines compared to the original OpenAI baselines + HER was re-implemented from scratch in Stable-Baselines compared to the original OpenAI baselines. + If you want to reproduce results from the paper, please use the rl baselines zoo + in order to have the correct hyperparameters and at least 8 MPI workers with DDPG. .. warning:: From 7688838d44cd7cd89f4f1e8eb7a6ae883b917c90 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 2 Jun 2019 21:32:31 +0200 Subject: [PATCH 31/36] Update doc: DDPG supports multiprocessing with MPI --- README.md | 2 +- docs/guide/algos.rst | 2 +- docs/guide/custom_env.rst | 5 ++--- docs/modules/ddpg.rst | 2 +- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 475603e081..5d28417ae2 100644 --- a/README.md +++ b/README.md @@ -147,7 +147,7 @@ All the following examples can be executed online using Google colab notebooks: | A2C | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | ACER | :heavy_check_mark: | :heavy_check_mark: | :x: (5) | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | | ACKTR | :heavy_check_mark: | :heavy_check_mark: | :x: (5) | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | -| DDPG | :heavy_check_mark: | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: | +| DDPG | :heavy_check_mark: | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: (4)| | DQN | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | :x: | :x: | :x: | | GAIL (2) | :heavy_check_mark: | :x: | :heavy_check_mark: |:heavy_check_mark:| :x: | :x: | :heavy_check_mark: (4) | | HER (3) | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :heavy_check_mark:| :x: | diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index f2967cdfe7..1a397c746f 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -18,7 +18,7 @@ Name Refactored [#f1]_ Recurrent ``Box`` ``Discrete`` Multi P A2C ✔️ ✔️ ✔️ ✔️ ✔️ ACER ✔️ ✔️ ❌ [#f4]_ ✔️ ✔️ ACKTR ✔️ ✔️ ❌ [#f4]_ ✔️ ✔️ -DDPG ✔️ ❌ ✔️ ❌ ❌ +DDPG ✔️ ❌ ✔️ ❌ ✔️ [#f3]_ DQN ✔️ ❌ ❌ ✔️ ❌ HER ✔️ ❌ ✔️ ✔️ ❌ GAIL [#f2]_ ✔️ ✔️ ✔️ ✔️ ✔️ [#f3]_ diff --git a/docs/guide/custom_env.rst b/docs/guide/custom_env.rst index 8825c4f7aa..02c8651c04 100644 --- a/docs/guide/custom_env.rst +++ b/docs/guide/custom_env.rst @@ -8,9 +8,8 @@ That is to say, your environment must implement the following methods (and inher .. note:: - - If you are using images as input, the input values must be in [0, 255] as the observation - is normalized (dividing by 255 to have values in [0, 1]) when using CNN policies. + If you are using images as input, the input values must be in [0, 255] as the observation + is normalized (dividing by 255 to have values in [0, 1]) when using CNN policies. diff --git a/docs/modules/ddpg.rst b/docs/modules/ddpg.rst index 07041cead8..bb86b6e593 100644 --- a/docs/modules/ddpg.rst +++ b/docs/modules/ddpg.rst @@ -36,7 +36,7 @@ Can I use? ---------- - Recurrent policies: ❌ -- Multi processing: ❌ +- Multi processing: ✔️ (using MPI) - Gym spaces: From cd1822542959020a63768a88aaecb3fc21f5e191 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 3 Jun 2019 00:05:46 +0200 Subject: [PATCH 32/36] Fix for new abstract method --- stable_baselines/her/her.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/stable_baselines/her/her.py b/stable_baselines/her/her.py index b125cc032e..de7f079c72 100644 --- a/stable_baselines/her/her.py +++ b/stable_baselines/her/her.py @@ -80,6 +80,9 @@ def set_env(self, env): def get_env(self): return self.env + def get_parameter_list(self): + return self.model.get_parameter_list() + def __getattr__(self, attr): """ Wrap the RL model. From 65ef63150b9173e343a3ac68448d30bbb9d4dcb5 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 3 Jun 2019 00:07:09 +0200 Subject: [PATCH 33/36] Update changelog --- docs/misc/changelog.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 92e2e637df..890c572639 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -24,7 +24,7 @@ Pre-Release 2.6.0a0 (WIP) this hack was present in the original OpenAI Baselines DDPG + HER implementation. - fixed a bug where initial learning rate is logged instead of its placeholder in ``A2C.setup_model`` (@sc420) - fixed a bug where number of timesteps is incorrectly updated and logged in ``A2C.learn`` and ``A2C._train_step`` (@sc420) -- added ``load_parameters`` and ``get_parameters`` for most learning algorithms. +- added ``load_parameters`` and ``get_parameters`` to base RL class. With these methods, users are able to load and get parameters to/from existing model, without touching tensorflow. (@Miffyli) - **important change** switched to using dictionaries rather than lists when storing parameters, with tensorflow Variable names being the keys. (@Miffyli) From 84af1666fcd7173c568565285c09913870ad4937 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 4 Jun 2019 09:14:11 +0200 Subject: [PATCH 34/36] Fix custom policy example --- docs/guide/custom_policy.rst | 14 ++++++-------- docs/misc/changelog.rst | 5 +++-- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst index 89b007a301..45b2a4310c 100644 --- a/docs/guide/custom_policy.rst +++ b/docs/guide/custom_policy.rst @@ -216,19 +216,18 @@ If your task requires even more granular control over the policy architecture, y value_fn = tf.layers.dense(vf_h, 1, name='vf') vf_latent = vf_h - self.proba_distribution, self.policy, self.q_value = \ + self._proba_distribution, self._policy, self.q_value = \ self.pdtype.proba_distribution_from_latent(pi_latent, vf_latent, init_scale=0.01) - self.value_fn = value_fn - self.initial_state = None + self._value_fn = value_fn self._setup_init() def step(self, obs, state=None, mask=None, deterministic=False): if deterministic: - action, value, neglogp = self.sess.run([self.deterministic_action, self._value, self.neglogp], + action, value, neglogp = self.sess.run([self.deterministic_action, self.value_flat, self.neglogp], {self.obs_ph: obs}) else: - action, value, neglogp = self.sess.run([self.action, self._value, self.neglogp], + action, value, neglogp = self.sess.run([self.action, self.value_flat, self.neglogp], {self.obs_ph: obs}) return action, value, self.initial_state, neglogp @@ -236,12 +235,11 @@ If your task requires even more granular control over the policy architecture, y return self.sess.run(self.policy_proba, {self.obs_ph: obs}) def value(self, obs, state=None, mask=None): - return self.sess.run(self._value, {self.obs_ph: obs}) + return self.sess.run(self.value_flat, {self.obs_ph: obs}) # Create and wrap the environment - env = gym.make('Breakout-v0') - env = DummyVecEnv([lambda: env]) + env = DummyVecEnv([lambda: gym.make('Breakout-v0')]) model = A2C(CustomPolicy, env, verbose=1) # Train the agent diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 890c572639..d3767c253c 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -54,7 +54,7 @@ Release 2.5.1 (2019-05-04) **Bug fixes + improvements in the VecEnv** -**Warning: breaking change when using custom recurrent policies** +**Warning: breaking changes when using custom policies** - doc update (fix example of result plotter + improve doc) - fixed logger issues when stdout lacks ``read`` function @@ -73,7 +73,8 @@ Release 2.5.1 (2019-05-04) to exactly one of the nested instances i.e. it must be unambiguous. (@kantneel) - fixed bug where result plotter would crash on very short runs (@Pastafarianist) - added option to not trim output of result plotter by number of timesteps (@Pastafarianist) -- clarified the public interface of ``BasePolicy`` and ``ActorCriticPolicy``. **Breaking change** when using custom policies: ``masks_ph`` is now called ``dones_ph``. +- clarified the public interface of ``BasePolicy`` and ``ActorCriticPolicy``. **Breaking change** when using custom policies: ``masks_ph`` is now called ``dones_ph``, + and most placeholders were made private: e.g. ``self.value_fn`` is now ``self._value_fn`` - support for custom stateful policies. - fixed episode length recording in ``trpo_mpi.utils.traj_segment_generator`` (@GerardMaggiolino) From e2408eb306aee9579109f2c144bedc79ebea04a3 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 4 Jun 2019 09:15:48 +0200 Subject: [PATCH 35/36] Add replay_wrapper to base OffPolicy class --- stable_baselines/common/base_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines/common/base_class.py b/stable_baselines/common/base_class.py index 8fd7aacf9c..d1bde6542b 100644 --- a/stable_baselines/common/base_class.py +++ b/stable_baselines/common/base_class.py @@ -703,7 +703,7 @@ def setup_model(self): @abstractmethod def learn(self, total_timesteps, callback=None, seed=None, - log_interval=100, tb_log_name="run", reset_num_timesteps=True): + log_interval=100, tb_log_name="run", reset_num_timesteps=True, replay_wrapper=None): pass @abstractmethod From 6ed497dcfb2c3996250b6d2e18551e05896f0a72 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 4 Jun 2019 09:26:13 +0200 Subject: [PATCH 36/36] Fix reimport --- stable_baselines/common/base_class.py | 1 - 1 file changed, 1 deletion(-) diff --git a/stable_baselines/common/base_class.py b/stable_baselines/common/base_class.py index d1bde6542b..a0637d8d41 100644 --- a/stable_baselines/common/base_class.py +++ b/stable_baselines/common/base_class.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from collections import OrderedDict import os import glob import warnings