Merge internal
nzlz committed Apr 11, 2019
2 parents 2d05ec0 + 6cf3d43 commit bde6006
Showing 51 changed files with 776 additions and 510 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
@@ -11,4 +11,4 @@ install:

script:
- flake8 . --show-source --statistics
- docker run baselines-test pytest -v --forked .
- docker run baselines-test pytest -v .
2 changes: 1 addition & 1 deletion README.md
@@ -100,7 +100,7 @@ python -m baselines.run --alg=ppo2 --env=Humanoid-v2 --network=mlp --num_timeste
will set entropy coefficient to 0.1, and construct fully connected network with 3 layers with 32 hidden units in each, and create a separate network for value function estimation (so that its parameters are not shared with the policy network, but the structure is the same)
See docstrings in [common/models.py](baselines/common/models.py) for description of network parameters for each type of model, and
docstring for [baselines/ppo2/ppo2.py/learn()](baselines/ppo2/ppo2.py#L152) for the description of the ppo2 hyperparamters.
docstring for [baselines/ppo2/ppo2.py/learn()](baselines/ppo2/ppo2.py#L152) for the description of the ppo2 hyperparameters.
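
For reference, the description above (entropy coefficient 0.1, a fully connected network of 3 layers with 32 hidden units each, a separate value-function network) corresponds to an invocation roughly like the sketch below. The flag names --ent_coef, --num_hidden, --num_layers and --value_network are assumptions taken from the docstrings referenced above, and the timestep count is a placeholder; check common/models.py and ppo2.learn for the exact names.

python -m baselines.run --alg=ppo2 --env=Humanoid-v2 --network=mlp --num_timesteps=1e6 --ent_coef=0.1 --num_hidden=32 --num_layers=3 --value_network=copy
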
### Example 2. DQN on Atari
DQN with Atari is at this point a classics of benchmarks. To run the baselines implementation of DQN on Atari Pong:
8 changes: 7 additions & 1 deletion baselines/a2c/a2c.py
@@ -11,6 +11,8 @@

from baselines.a2c.utils import Scheduler, find_trainable_variables
from baselines.a2c.runner import Runner
from baselines.ppo2.ppo2 import safemean
from collections import deque

from tensorflow import losses

@@ -195,6 +197,7 @@ def learn(

# Instantiate the runner object
runner = Runner(env, model, nsteps=nsteps, gamma=gamma)
epinfobuf = deque(maxlen=100)

# Calculate the batch_size
nbatch = nenvs*nsteps
@@ -204,7 +207,8 @@

for update in range(1, total_timesteps//nbatch+1):
# Get mini batch of experiences
obs, states, rewards, masks, actions, values = runner.run()
obs, states, rewards, masks, actions, values, epinfos = runner.run()
epinfobuf.extend(epinfos)

policy_loss, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)
nseconds = time.time()-tstart
@@ -221,5 +225,7 @@ def learn(
logger.record_tabular("policy_entropy", float(policy_entropy))
logger.record_tabular("value_loss", float(value_loss))
logger.record_tabular("explained_variance", float(ev))
logger.record_tabular("eprewmean", safemean([epinfo['r'] for epinfo in epinfobuf]))
logger.record_tabular("eplenmean", safemean([epinfo['l'] for epinfo in epinfobuf]))
logger.dump_tabular()
return model
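
The eprewmean/eplenmean lines added above rely on safemean imported from ppo2 plus a rolling deque of per-episode info dicts. A minimal sketch, assuming safemean keeps its ppo2 definition (NaN for an empty buffer):

import numpy as np
from collections import deque

def safemean(xs):
    # NaN until the first episode has finished, instead of np.mean warning on an empty list
    return np.nan if len(xs) == 0 else np.mean(xs)

epinfobuf = deque(maxlen=100)                 # keeps only the last 100 finished episodes
epinfobuf.extend([{'r': 10.0, 'l': 200}])     # each epinfo dict carries episode reward 'r' and length 'l'
print(safemean([epinfo['r'] for epinfo in epinfobuf]))   # -> 10.0
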
8 changes: 6 additions & 2 deletions baselines/a2c/runner.py
@@ -22,6 +22,7 @@ def run(self):
# We initialize the lists that will contain the mb of experiences
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]
mb_states = self.states
epinfos = []
for n in range(self.nsteps):
# Given observations, take action and value (V(s))
# We already have self.obs because Runner superclass run self.obs[:] = env.reset() on init
@@ -34,7 +35,10 @@
mb_dones.append(self.dones)

# Take actions in env and look the results
obs, rewards, dones, _ = self.env.step(actions)
obs, rewards, dones, infos = self.env.step(actions)
for info in infos:
maybeepinfo = info.get('episode')
if maybeepinfo: epinfos.append(maybeepinfo)
self.states = states
self.dones = dones
self.obs = obs
@@ -69,4 +73,4 @@ def run(self):
mb_rewards = mb_rewards.flatten()
mb_values = mb_values.flatten()
mb_masks = mb_masks.flatten()
return mb_obs, mb_states, mb_rewards, mb_masks, mb_actions, mb_values
return mb_obs, mb_states, mb_rewards, mb_masks, mb_actions, mb_values, epinfos
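
The epinfos returned above are the 'episode' dicts that baselines.bench.Monitor injects into info when an episode finishes (the Monitor side is changed further down in this commit). A minimal sketch of that contract on a single env, assuming a standard Gym install:

import gym
from baselines.bench import Monitor

env = Monitor(gym.make('CartPole-v1'), filename=None)    # no log file, episode stats only
obs = env.reset()
done, info = False, {}
while not done:
    obs, rew, done, info = env.step(env.action_space.sample())
print(info.get('episode'))    # e.g. {'r': 23.0, 'l': 23, 't': 0.05}
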
22 changes: 14 additions & 8 deletions baselines/acktr/acktr.py
@@ -14,6 +14,8 @@
from baselines.acktr.runner import Runner
from baselines.a2c.utils import Scheduler, find_trainable_variables
from baselines.acktr import kfac
from baselines.ppo2.ppo2 import safemean
from collections import deque


class Model(object):
@@ -119,6 +121,7 @@ def learn(network, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interva
model.load(load_path)

runner = Runner(env, model, nsteps=nsteps, gamma=gamma)
epinfobuf = deque(maxlen=100)
nbatch = nenvs*nsteps
tstart = time.time()
coord = tf.train.Coordinator()
@@ -135,7 +138,6 @@ def learn(network, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interva
best_savepath = os.path.join(checkdir, "best")

for update in range(1, total_timesteps//nbatch+1):
# obs, states, rewards, masks, actions, values = runner.run()
obs, states, rewards, masks, actions, values, epinfos = runner.run()
epinfobuf.extend(epinfos)
policy_loss, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)
@@ -146,14 +148,18 @@ def learn(network, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interva
ev = explained_variance(values, rewards)
mean_rewbuffer = safemean([epinfo['r'] for epinfo in epinfobuf])
logger.logkv('eprewmean', mean_rewbuffer)

logger.record_tabular("nupdates", update)
logger.record_tabular("total_timesteps", update*nbatch)
logger.record_tabular("fps", fps)
logger.record_tabular("policy_entropy", float(policy_entropy))
logger.record_tabular("policy_loss", float(policy_loss))
logger.record_tabular("value_loss", float(value_loss))
logger.record_tabular("explained_variance", float(ev))
logger.record_tabular("eprewmean", safemean([epinfo['r'] for epinfo in epinfobuf]))
logger.record_tabular("eplenmean", safemean([epinfo['l'] for epinfo in epinfobuf]))
logger.logkv('eprewsem', np.std([epinfo['r'] for epinfo in epinfobuf]))
logger.logkv("nupdates", update)
logger.logkv("total_timesteps", update*nbatch)
logger.logkv("fps", fps)
logger.logkv("policy_entropy", float(policy_entropy))
logger.logkv("policy_loss", float(policy_loss))
logger.logkv("value_loss", float(value_loss))
logger.logkv("explained_variance", float(ev))

logger.dump_tabular()

if save_interval and logger.get_dir():
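
The logging block above switches from logger.record_tabular to logger.logkv. To the best of my reading of baselines.logger, record_tabular/dump_tabular are aliases of logkv/dumpkvs, so the change is cosmetic; the pattern being used is simply:

from baselines import logger

logger.logkv('eprewmean', 12.3)    # queue key/value pairs for the current row
logger.logkv('nupdates', 42)
logger.dumpkvs()                   # flush the row to the configured output(s)
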
2 changes: 1 addition & 1 deletion baselines/bench/benchmarks.py
@@ -20,7 +20,7 @@ def register_benchmark(benchmark):
if 'tasks' in benchmark:
for t in benchmark['tasks']:
if 'desc' not in t:
t['desc'] = remove_version_re.sub('', t['env_id'])
t['desc'] = remove_version_re.sub('', t.get('env_id', t.get('id')))
_BENCHMARKS.append(benchmark)


52 changes: 27 additions & 25 deletions baselines/bench/monitor.py
@@ -16,11 +16,13 @@ class Monitor(Wrapper):
def __init__(self, env, filename, allow_early_resets=False, reset_keywords=(), info_keywords=()):
Wrapper.__init__(self, env=env)
self.tstart = time.time()
self.results_writer = ResultsWriter(
filename,
header={"t_start": time.time(), 'env_id' : env.spec and env.spec.id},
extra_keys=reset_keywords + info_keywords
)
if filename:
self.results_writer = ResultsWriter(filename,
header={"t_start": time.time(), 'env_id' : env.spec and env.spec.id},
extra_keys=reset_keywords + info_keywords
)
else:
self.results_writer = None
self.reset_keywords = reset_keywords
self.info_keywords = info_keywords
self.allow_early_resets = allow_early_resets
@@ -80,8 +82,9 @@ def update(self, ob, rew, done, info):
self.episode_lengths.append(eplen)
self.episode_times.append(time.time() - self.tstart)
epinfo.update(self.current_reset_info)
self.results_writer.write_row(epinfo)

if self.results_writer:
self.results_writer.write_row(epinfo)
assert isinstance(info, dict)
if isinstance(info, dict):
info['episode'] = epinfo

@@ -91,6 +94,9 @@ def close(self):
if self.f is not None:
self.f.close()

def gg2(self):
return self.env

def get_total_steps(self):
return self.total_steps

@@ -108,32 +114,28 @@ class LoadMonitorResultsError(Exception):


class ResultsWriter(object):
def __init__(self, filename=None, header='', extra_keys=()):
def __init__(self, filename, header='', extra_keys=()):
self.extra_keys = extra_keys
if filename is None:
self.f = None
self.logger = None
else:
if not filename.endswith(Monitor.EXT):
if osp.isdir(filename):
filename = osp.join(filename, Monitor.EXT)
else:
filename = filename + "." + Monitor.EXT
self.f = open(filename, "wt")
if isinstance(header, dict):
header = '# {} \n'.format(json.dumps(header))
self.f.write(header)
self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+tuple(extra_keys))
self.logger.writeheader()
self.f.flush()
assert filename is not None
if not filename.endswith(Monitor.EXT):
if osp.isdir(filename):
filename = osp.join(filename, Monitor.EXT)
else:
filename = filename + "." + Monitor.EXT
self.f = open(filename, "wt")
if isinstance(header, dict):
header = '# {} \n'.format(json.dumps(header))
self.f.write(header)
self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+tuple(extra_keys))
self.logger.writeheader()
self.f.flush()

def write_row(self, epinfo):
if self.logger:
self.logger.writerow(epinfo)
self.f.flush()



def get_monitor_files(dir):
return glob(osp.join(dir, "*" + Monitor.EXT))

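
The net effect of the Monitor/ResultsWriter changes above: the "no output file" case is now handled inside Monitor (results_writer stays None and nothing is written), while ResultsWriter asserts it was given a filename. A rough usage sketch:

import gym
from baselines.bench import Monitor

env_logged   = Monitor(gym.make('CartPole-v1'), '/tmp/run0')   # writes /tmp/run0.monitor.csv
env_unlogged = Monitor(gym.make('CartPole-v1'), None)          # keeps episode stats, writes no file
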
9 changes: 5 additions & 4 deletions baselines/common/atari_wrappers.py
@@ -6,6 +6,8 @@
from gym import spaces
import cv2
cv2.ocl.setUseOpenCL(False)
from .wrappers import TimeLimit


class NoopResetEnv(gym.Wrapper):
def __init__(self, env, noop_max=30):
@@ -221,14 +223,13 @@ def __len__(self):
def __getitem__(self, i):
return self._force()[i]

def make_atari(env_id, timelimit=True):
# XXX(john): remove timelimit argument after gym is upgraded to allow double wrapping
def make_atari(env_id, max_episode_steps=None):
env = gym.make(env_id)
if not timelimit:
env = env.env
assert 'NoFrameskip' in env.spec.id
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps=max_episode_steps)
return env

def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False):
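
make_atari now defers episode truncation to a TimeLimit wrapper imported from baselines.common.wrappers, which this page does not show. Below is the new usage plus a minimal TimeLimit consistent with how it is used above (a sketch, not the actual implementation):

import gym
from baselines.common.atari_wrappers import make_atari

env = make_atari('PongNoFrameskip-v4', max_episode_steps=4500)   # 4500 is an arbitrary cap

class TimeLimit(gym.Wrapper):
    # End the episode once max_episode_steps environment steps have elapsed.
    def __init__(self, env, max_episode_steps=None):
        super(TimeLimit, self).__init__(env)
        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = 0

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self._elapsed_steps += 1
        if self._max_episode_steps is not None and self._elapsed_steps >= self._max_episode_steps:
            done = True
        return obs, reward, done, info

    def reset(self, **kwargs):
        self._elapsed_steps = 0
        return self.env.reset(**kwargs)
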
16 changes: 10 additions & 6 deletions baselines/common/cmd_util.py
@@ -30,16 +30,19 @@ def make_vec_env(env_id, env_type, num_env, seed,
wrapper_kwargs = wrapper_kwargs or {}
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
seed = seed + 10000 * mpi_rank if seed is not None else None
logger_dir = logger.get_dir()
def make_thunk(rank):
return lambda: make_env(
env_id=env_id,
env_type=env_type,
subrank = rank,
mpi_rank=mpi_rank,
subrank=rank,
seed=seed,
reward_scale=reward_scale,
gamestate=gamestate,
flatten_dict_observations=flatten_dict_observations,
wrapper_kwargs=wrapper_kwargs
wrapper_kwargs=wrapper_kwargs,
logger_dir=logger_dir
)

set_global_seeds(seed)
@@ -49,8 +52,7 @@ def make_thunk(rank):
return DummyVecEnv([make_thunk(start_index)])


def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None):
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, logger_dir=None):
wrapper_kwargs = wrapper_kwargs or {}
if env_type == 'atari':
env = make_atari(env_id)
@@ -67,12 +69,14 @@ def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate

env.seed(seed + subrank if seed is not None else None)
env = Monitor(env,
logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(subrank)),
logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank)),
allow_early_resets=True)

if env_type == 'atari':
env = wrap_deepmind(env, **wrapper_kwargs)
elif env_type == 'retro':
if 'frame_stack' not in wrapper_kwargs:
wrapper_kwargs['frame_stack'] = 1
env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)

if reward_scale != 1:
@@ -133,6 +137,7 @@ def common_arg_parser():
"""
parser = arg_parser()
parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2')
parser.add_argument('--env_type', help='type of environment, used when the environment type cannot be automatically determined', type=str)
parser.add_argument('--seed', help='RNG seed', type=int, default=None)
parser.add_argument('--alg', help='Algorithm', type=str, default='ppo2')
parser.add_argument('--num_timesteps', type=float, default=1e6),
@@ -144,7 +149,6 @@ def common_arg_parser():
parser.add_argument('--save_video_interval', help='Save video every x steps (0 = disabled)', default=0, type=int)
parser.add_argument('--save_video_length', help='Length of recorded video. Default: 200', default=200, type=int)
parser.add_argument('--play', default=False, action='store_true')
parser.add_argument('--extra_import', help='Extra module to import to access external environments', type=str, default=None)
return parser

def robotics_arg_parser():
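
With the signature changes above, mpi_rank and logger_dir are resolved once in the parent process and passed into each env thunk, so subprocess workers no longer call MPI or logger.get_dir() themselves. Call sites stay the same; a typical one (a sketch, assuming an Atari-enabled Gym install):

from baselines.common.cmd_util import make_vec_env

venv = make_vec_env('PongNoFrameskip-v4', 'atari', num_env=4, seed=0)
obs = venv.reset()    # one observation per sub-environment, stacked along axis 0
venv.close()
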
3 changes: 2 additions & 1 deletion baselines/common/distributions.py
@@ -206,7 +206,8 @@ def fromflat(cls, flat):
class MultiCategoricalPd(Pd):
def __init__(self, nvec, flat):
self.flat = flat
self.categoricals = list(map(CategoricalPd, tf.split(flat, nvec, axis=-1)))
self.categoricals = list(map(CategoricalPd,
tf.split(flat, np.array(nvec, dtype=np.int32), axis=-1)))
def flatparam(self):
return self.flat
def mode(self):
21 changes: 0 additions & 21 deletions baselines/common/misc_util.py
@@ -13,27 +13,6 @@ def zipsame(*seqs):
return zip(*seqs)


def unpack(seq, sizes):
"""
Unpack 'seq' into a sequence of lists, with lengths specified by 'sizes'.
None = just one bare element, not a list
Example:
unpack([1,2,3,4,5,6], [3,None,2]) -> ([1,2,3], 4, [5,6])
"""
seq = list(seq)
it = iter(seq)
assert sum(1 if s is None else s for s in sizes) == len(seq), "Trying to unpack %s into %s" % (seq, sizes)
for size in sizes:
if size is None:
yield it.__next__()
else:
li = []
for _ in range(size):
li.append(it.__next__())
yield li


class EzPickle(object):
"""Objects that are pickled and unpickled via their constructor
arguments.
(Diffs for the remaining changed files in this commit were not loaded.)
