Skip to content

Commit

Permalink
add posibility to add stats
Browse files Browse the repository at this point in the history
  • Loading branch information
Elbarmo committed Apr 5, 2019
1 parent a3d2ca3 commit 5205590
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
10 changes: 5 additions & 5 deletions baselines/ppo2/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def mara_mlp():
log_interval = 1,
ent_coef = 0.0,
lr = lambda f: 3e-3 * f,
cliprange = 0.2,
vf_coef = 0.5,
cliprange = 0.25,
vf_coef = 1,
max_grad_norm = 0.5,
seed = 0,
value_network = 'copy',
Expand All @@ -51,7 +51,7 @@ def mara_mlp():
# env_name = 'MARACollisionOrient-v0',
transfer_path = None,
# transfer_path = '/tmp/ros2learn/MARA-v0/ppo2_mlp/2019-02-19_12h47min/checkpoints/best',
trained_path = '/tmp/ros2learn/MARAOrient-v0/ppo2_mlp/2019-03-26_14h27min/checkpoints/best'
trained_path = '/tmp/ros2learn/MARA-v0/ppo2_mlp/2019-04-02_13h18min/checkpoints/best'
)

def mara_lstm():
Expand Down Expand Up @@ -79,8 +79,8 @@ def mara_lstm():
network = 'lstm',
total_timesteps = 1e8,
save_interval = 10,
env_name = 'MARACollisionOrientRandomTarget-v0',
num_envs = 2,
env_name = 'MARA-v0',
num_envs = 4,
transfer_path = None,
# transfer_path = '/tmp/ros2learn/MARACollisionOrientRandomTarget-v0/ppo2_lstm/checkpoints/00090',
trained_path = '/tmp/ros2learn/MARACollisionOrientRandomTarget-v0/ppo2_lstm/checkpoints/00090'
Expand Down
8 changes: 7 additions & 1 deletion baselines/ppo2/ppo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
cliprangenow = cliprange(frac)
# Get minibatch
obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run() #pylint: disable=E0632

if eval_env is not None:
eval_obs, eval_returns, eval_masks, eval_actions, eval_values, eval_neglogpacs, eval_states, eval_epinfos = eval_runner.run() #pylint: disable=E0632

Expand Down Expand Up @@ -194,7 +195,7 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
logger.logkv("fps", fps)
logger.logkv("explained_variance", float(ev))
mean_rewbuffer = safemean([epinfo['r'] for epinfo in epinfobuf])
logger.logkv('eprewmean', mean_rewbuffer)
logger.logkv('eprewmean_smooth', mean_rewbuffer)
logger.logkv('eprewsem', np.std([epinfo['r'] for epinfo in epinfobuf]))
logger.logkv('eplenmean', safemean([epinfo['l'] for epinfo in epinfobuf]))
if eval_env is not None:
Expand All @@ -203,6 +204,11 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
logger.logkv('time_elapsed', tnow - tfirststart)
for (lossval, lossname) in zip(lossvals, model.loss_names):
logger.logkv(lossname, lossval)

key_set = [key for key in list(epinfobuf)[0].keys() if key not in ["r", "l", "t"]]
for key in key_set:
logger.logkv(key, list(epinfobuf)[-1][key])

if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
logger.dumpkvs()

Expand Down
8 changes: 4 additions & 4 deletions baselines/ppo2/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def run(self):
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones, mb_neglogpacs = [],[],[],[],[],[]
mb_states = self.states
epinfos = []
maybeepinfo = []
# For n in range number of steps
for _ in range(self.nsteps):
# Given observations, get action value and neglopacs
Expand All @@ -37,8 +38,9 @@ def run(self):
# Infos contains a ton of useful informations
self.obs[:], rewards, self.dones, infos = self.env.step(actions)
for info in infos:
maybeepinfo = info.get('episode')
if maybeepinfo: epinfos.append(maybeepinfo)
for key in info.keys():
maybeepinfo.append(info.get(key))
if maybeepinfo: epinfos.append({key:dict[key] for dict in maybeepinfo for key in dict})
mb_rewards.append(rewards)
#batch of steps to batch of rollouts
mb_obs = np.asarray(mb_obs, dtype=self.obs.dtype)
Expand Down Expand Up @@ -72,5 +74,3 @@ def sf01(arr):
"""
s = arr.shape
return arr.swapaxes(0, 1).reshape(s[0] * s[1], *s[2:])


0 comments on commit 5205590

Please sign in to comment.