From 36e3f701d5fa7f02864f747c65568c5d74cf36a6 Mon Sep 17 00:00:00 2001 From: Danijar Hafner Date: Sun, 28 Apr 2024 22:39:23 +0000 Subject: [PATCH] Clean up API and tests --- README.md | 2 +- dreamerv3/agent.py | 74 +++++----- dreamerv3/configs.yaml | 13 +- dreamerv3/jaxagent.py | 37 +++-- dreamerv3/jaxutils.py | 25 ++-- dreamerv3/main.py | 36 ++++- dreamerv3/nets.py | 14 +- embodied/core/base.py | 4 +- embodied/core/logger.py | 25 ++++ embodied/core/random_agent.py | 7 +- embodied/replay/replay.py | 10 +- embodied/replay/selectors.py | 2 - embodied/run/parallel.py | 10 +- embodied/run/parallel_with_eval.py | 22 ++- embodied/run/train.py | 4 +- embodied/run/train_eval.py | 7 +- embodied/run/train_holdout.py | 7 +- embodied/tests/distr/test_process.py | 2 +- embodied/tests/run/test_parallel.py | 17 ++- embodied/tests/run/test_train.py | 5 +- embodied/tests/run/utils.py | 9 +- embodied/tests/test_driver.py | 6 +- embodied/tests/test_replay.py | 204 +++++++++------------------ 23 files changed, 287 insertions(+), 255 deletions(-) diff --git a/README.md b/README.md index 61b18b8a..a6706e96 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ increases data efficiency. # Instructions -The code has been tested on Linux and Mac. +The code has been tested on Linux and Mac and requires Python 3.11+. 
## Docker diff --git a/dreamerv3/agent.py b/dreamerv3/agent.py index 03654f81..c90290bd 100644 --- a/dreamerv3/agent.py +++ b/dreamerv3/agent.py @@ -43,8 +43,6 @@ def __init__(self, obs_space, act_space, config): not k.startswith('log_') and re.match(config.dec.spaces, k)} embodied.print('Encoder:', {k: v.shape for k, v in enc_space.items()}) embodied.print('Decoder:', {k: v.shape for k, v in dec_space.items()}) - # nets.Initializer.VARIANCE_FACTOR = self.config.init_scale - nets.Initializer.FORCE_STDDEV = self.config.winit_scale # World Model self.enc = { @@ -85,16 +83,6 @@ def __init__(self, obs_space, act_space, config): # Optimizer kw = dict(config.opt) lr = kw.pop('lr') - if config.compute_lr: - assert not config.separate_lrs - width = float(config.actor.units) - replay = float(config.run.train_ratio) - a = config.compute_lr_params.lnwidth - b = config.compute_lr_params.lnreplay - c = config.compute_lr_params.bias - lr = np.exp(a * np.log(width) + b * np.log(replay) + c) - message = f'Computed LR (width={width}, replay={replay}): {lr:.1e}' - embodied.print(message) if config.separate_lrs: lr = {f'agent/{k}': v for k, v in config.lrs.items()} self.opt = jaxutils.Optimizer(lr, **kw, name='opt') @@ -114,7 +102,6 @@ def policy_keys(self): @property def aux_spaces(self): - import numpy as np spaces = {} spaces['stepid'] = embodied.Space(np.uint8, 20) if self.config.replay_context: @@ -142,6 +129,9 @@ def init_train(self, batch_size): for k, v in self.act_space.items()} return (self.dyn.initial(batch_size), prevact) + def init_report(self, batch_size): + return self.init_train(batch_size) + def policy(self, obs, carry, mode='train'): self.config.jax.jit and embodied.print( 'Tracing policy function', color='yellow') @@ -332,13 +322,8 @@ def imgstep(carry, _): adv_normed = (adv - aoffset) / ascale logpi = sum([v.log_prob(sg(acts[k]))[:, :-1] for k, v in actor.items()]) ents = {k: v.entropy()[:, :-1] for k, v in actor.items()} - if self.config.scale_by_actent: - 
actor_loss = sg(weight[:, :-1]) * -( - logpi * sg(adv_normed) * (1 / self.config.actent) + - sum(ents.values())) - else: - actor_loss = sg(weight[:, :-1]) * -( - logpi * sg(adv_normed) + self.config.actent * sum(ents.values())) + actor_loss = sg(weight[:, :-1]) * -( + logpi * sg(adv_normed) + self.config.actent * sum(ents.values())) losses['actor'] = actor_loss # Critic @@ -414,34 +399,55 @@ def imgstep(carry, _): losses = {k: v * self.scales[k] for k, v in losses.items()} loss = jnp.stack([v.mean() for k, v in losses.items()]).sum() newact = {k: data[k][:, -1] for k in self.act_space} - outs = {'replay_outs': replay_outs, 'prevacts': prevacts} + outs = {'replay_outs': replay_outs, 'prevacts': prevacts, 'embed': embed} outs.update({f'{k}_loss': v for k, v in losses.items()}) carry = (newlat, newact) return loss, (outs, carry, metrics) - def report(self, data): + def report(self, data, carry): self.config.jax.jit and embodied.print( 'Tracing report function', color='yellow') if not self.config.report: - return {} + return {}, carry metrics = {} data = self.preprocess(data) # Train metrics - carry = self.init_train(len(data['is_first'])) - _, (outs, _, mets) = self.loss(data, carry, update=False) + _, (outs, carry_out, mets) = self.loss(data, carry, update=False) metrics.update(mets) + # Open loop predictions + B, T = data['is_first'].shape + num_obs = min(self.config.report_openl_context, T // 2) + # Rerun observe to get the correct intermediate state, because + # outs_to_carry doesn't work with num_obs= args.num_envs + embodied.run.parallel_with_eval.parallel_env( + bind(make_env, config), envid, args, True, is_eval) + + elif args.script == 'parallel_with_eval_replay': + embodied.run.parallel_with_eval.parallel_replay( + bind(make_replay, config, 'replay', rate_limit=True), + bind(make_replay, config, 'replay_eval', is_eval=True), args) else: raise NotImplementedError(args.script) @@ -136,7 +162,7 @@ def make_logger(config): def make_replay(config, directory=None, 
is_eval=False, rate_limit=False): directory = directory and embodied.Path(config.logdir) / directory size = int(config.replay.size / 10 if is_eval else config.replay.size) - length = config.batch_length + length = config.replay_length_eval if is_eval else config.replay_length kwargs = {} kwargs['online'] = config.replay.online if rate_limit and config.run.train_ratio > 0: diff --git a/dreamerv3/nets.py b/dreamerv3/nets.py index e3f45ad4..52b051db 100644 --- a/dreamerv3/nets.py +++ b/dreamerv3/nets.py @@ -230,7 +230,7 @@ class SimpleEncoder(nj.Module): norm: str = 'rms' act: str = 'gelu' kernel: int = 4 - debug_outer: bool = False + outer: bool = False minres: int = 4 def __init__(self, spaces, **kw): @@ -260,7 +260,7 @@ def __call__(self, data, bdims=2): x = self.imginp(data, bdims, jaxutils.COMPUTE_DTYPE) - 0.5 x = x.reshape((-1, *x.shape[bdims:])) for i, depth in enumerate(self.depths): - stride = 1 if self.debug_outer and i == 0 else 2 + stride = 1 if self.outer and i == 0 else 2 x = self.get(f'conv{i}', Conv2D, depth, self.kernel, stride, **kw)(x) assert x.shape[-3] == x.shape[-2] == self.minres, x.shape x = x.reshape((x.shape[0], -1)) @@ -285,7 +285,7 @@ class SimpleDecoder(nj.Module): outscale: float = 1.0 vecdist: str = 'symlog_mse' kernel: int = 4 - debug_outer: bool = False + outer: bool = False block_fans: bool = False block_norm: bool = False block_space: int = 0 @@ -351,7 +351,7 @@ def __call__(self, lat, bdims=2): x = self.get( f'conv{i}', Conv2D, depth, self.kernel, 2, **kw, transp=True)(x) outkw = dict(**self.kw, outscale=self.outscale, transp=True) - stride = 1 if self.debug_outer else 2 + stride = 1 if self.outer else 2 x = self.get( 'imgout', Conv2D, self.imgdep, self.kernel, stride, **outkw)(x) x = jax.nn.sigmoid(x) if self.sigmoid else x + 0.5 @@ -816,7 +816,6 @@ def __call__(self, inputs, bdims=2, dtype=None): class Initializer: VARIANCE_FACTOR = 1.0 - FORCE_STDDEV = 0.0 def __init__( self, dist='normal', scale=1.0, fan='in', dtype='default', 
@@ -841,10 +840,7 @@ def __call__(self, shape, fan_shape=None): value = jax.random.uniform(nj.seed(), shape, dtype, -limit, limit) elif self.dist == 'normal': value = jax.random.truncated_normal(nj.seed(), -2, 2, shape) - if self.FORCE_STDDEV > 0.0: - value *= 1.1368 * self.FORCE_STDDEV - else: - value *= 1.1368 * np.sqrt(self.VARIANCE_FACTOR / fan) + value *= 1.1368 * np.sqrt(self.VARIANCE_FACTOR / fan) value = value.astype(dtype) elif self.dist == 'normed': value = jax.random.uniform(nj.seed(), shape, dtype, -1, 1) diff --git a/embodied/core/base.py b/embodied/core/base.py index 76317896..189d4666 100644 --- a/embodied/core/base.py +++ b/embodied/core/base.py @@ -25,9 +25,9 @@ def train(self, data, carry=None): raise NotImplementedError( 'train(data, carry=None) -> outs, carry, metrics') - def report(self, data): + def report(self, data, carry=None): raise NotImplementedError( - 'report(data) -> metrics') + 'report(data, carry=None) -> metrics, carry') def dataset(self, generator_fn): raise NotImplementedError( diff --git a/embodied/core/logger.py b/embodied/core/logger.py index d8a59d2b..a88ee184 100644 --- a/embodied/core/logger.py +++ b/embodied/core/logger.py @@ -348,6 +348,31 @@ def _setup(self, run_name, resume_id, config): self._mlflow.start_run(run_name=run_name, tags=tags) +class ExpaOutput: + + def __init__(self, exp, run, project, user, config=None): + try: + import expa + print(f'Expa: {exp}/{run} ({project})') + self._expa = expa.Logger( + exp, run, project, user, api_url='pubsub://expa-dev/ingest') + if config: + self._expa.log_params(dict(config)) + except Exception as e: + print(f'Error exporting Expa: {e}') + self._expa = None + return + + def __call__(self, summaries): + if not self._expa: + return + bystep = collections.defaultdict(dict) + for step, name, value in summaries: + bystep[step][name] = value + for step, metrics in bystep.items(): + self._expa.log(metrics, step) + + @timer.section('gif') def _encode_gif(frames, fps): from subprocess 
import Popen, PIPE diff --git a/embodied/core/random_agent.py b/embodied/core/random_agent.py index 9ad033f9..ef04edfa 100644 --- a/embodied/core/random_agent.py +++ b/embodied/core/random_agent.py @@ -13,6 +13,9 @@ def init_policy(self, batch_size): def init_train(self, batch_size): return () + def init_report(self, batch_size): + return () + def policy(self, obs, carry=(), mode='train'): batch_size = len(obs['is_first']) act = { @@ -26,9 +29,9 @@ def train(self, data, carry=()): metrics = {} return outs, carry, metrics - def report(self, data): + def report(self, data, carry=()): report = {} - return report + return report, carry def dataset(self, generator): return generator() diff --git a/embodied/replay/replay.py b/embodied/replay/replay.py index d95d566f..e8e6da26 100644 --- a/embodied/replay/replay.py +++ b/embodied/replay/replay.py @@ -17,7 +17,7 @@ class Replay: def __init__( self, length, capacity=None, directory=None, chunksize=1024, min_size=1, samples_per_insert=None, tolerance=1e4, online=False, selector=None, - debug_save_wait=False, seed=0): + save_wait=False, seed=0): assert not capacity or min_size <= capacity self.length = length @@ -57,7 +57,7 @@ def __init__( else: self.directory = None - self.debug_save_wait = debug_save_wait + self.save_wait = save_wait self.metrics = { 'samples': 0, @@ -285,8 +285,8 @@ def _assemble_batch(self, seqs, start, stop): for key, parts in seqs[0].items()} for n, seq in enumerate(seqs): st, dt = 0, 0 # Source and destination time index. 
- for p in range(len(seq['is_first'])): - partlen = len(seq['is_first'][p]) + for p in range(len(seq['stepid'])): + partlen = len(seq['stepid'][p]) if start < st + partlen: part_start = max(0, start - st) part_stop = min(stop - st, partlen) @@ -329,7 +329,7 @@ def save(self): if chunk.length > 0 and chunk.uuid not in self.saved: self.saved.add(chunk.uuid) promises.append(self.workers.submit(chunk.save, self.directory)) - if self.debug_save_wait: + if self.save_wait: [promise.result() for promise in promises] return {'limiter': self.limiter.save()} diff --git a/embodied/replay/selectors.py b/embodied/replay/selectors.py index e6640a12..0f50fcc7 100644 --- a/embodied/replay/selectors.py +++ b/embodied/replay/selectors.py @@ -111,8 +111,6 @@ def _build(self, uprobs, bfactor=16): class Prioritized: - # TODO: Checkpoint priorities. - def __init__( self, exponent=1.0, initial=1.0, zero_on_sample=False, maxfrac=0.0, branching=16, seed=0): diff --git a/embodied/run/parallel.py b/embodied/run/parallel.py index ae50e606..fa19d544 100644 --- a/embodied/run/parallel.py +++ b/embodied/run/parallel.py @@ -110,8 +110,8 @@ def parallel_learner(agent, barrier, args): should_log = embodied.when.Clock(args.log_every) should_eval = embodied.when.Clock(args.eval_every) should_save = embodied.when.Clock(args.save_every) - batch_steps = args.batch_size * (args.batch_length - args.replay_context) fps = embodied.FPS() + batch_steps = args.batch_size * (args.batch_length - args.replay_context) checkpoint = embodied.Checkpoint(logdir / 'checkpoint.ckpt') checkpoint.agent = agent @@ -140,7 +140,8 @@ def parallel_dataset(source, prefetch=2): dataset_train = agent.dataset(bind(parallel_dataset, 'sample_batch_train')) dataset_report = agent.dataset(bind(parallel_dataset, 'sample_batch_report')) - state = agent.init_train(args.batch_size) + carry = agent.init_train(args.batch_size) + carry_report = agent.init_report(args.batch_size) should_save() # Delay first save. 
should_eval() # Delay first eval. @@ -149,7 +150,7 @@ def parallel_dataset(source, prefetch=2): with embodied.timer.section('learner_batch_next'): batch = next(dataset_train) with embodied.timer.section('learner_train_step'): - outs, state, mets = agent.train(batch, state) + outs, carry, mets = agent.train(batch, carry) if 'replay' in outs: with embodied.timer.section('learner_replay_update'): updater.update(outs['replay']) @@ -159,7 +160,8 @@ def parallel_dataset(source, prefetch=2): if should_eval(): with embodied.timer.section('learner_eval'): - logger.add(prefix(agent.report(next(dataset_report)), 'report')) + mets, _ = agent.report(next(dataset_report), carry_report) + logger.add(prefix(mets, 'report')) if should_log(): with embodied.timer.section('learner_metrics'): diff --git a/embodied/run/parallel_with_eval.py b/embodied/run/parallel_with_eval.py index 7422db7c..c806062f 100644 --- a/embodied/run/parallel_with_eval.py +++ b/embodied/run/parallel_with_eval.py @@ -119,8 +119,8 @@ def parallel_learner(agent, barrier, args): should_log = embodied.when.Clock(args.log_every) should_eval = embodied.when.Clock(args.eval_every) should_save = embodied.when.Clock(args.save_every) - batch_steps = args.batch_size * (args.batch_length - args.replay_context) fps = embodied.FPS() + batch_steps = args.batch_size * (args.batch_length - args.replay_context) checkpoint = embodied.Checkpoint(logdir / 'checkpoint.ckpt') checkpoint.agent = agent @@ -150,10 +150,20 @@ def parallel_dataset(source, prefetch=2): received[source] += 1 yield batch + def evaluate(dataset): + num_batches = args.replay_length_eval // args.batch_length_eval + carry = agent.init_report(args.batch_size) + agg = embodied.Agg() + for _ in range(num_batches): + batch = next(dataset) + metrics, carry = agent.report(batch, carry) + agg.add(metrics) + return agg.result() + dataset_train = agent.dataset(bind(parallel_dataset, 'train')) dataset_report = agent.dataset(bind(parallel_dataset, 'report')) dataset_eval = 
agent.dataset(bind(parallel_dataset, 'eval')) - state = agent.init_train(args.batch_size) + carry = agent.init_train(args.batch_size) should_save() # Delay first save. should_eval() # Delay first eval. @@ -162,7 +172,7 @@ def parallel_dataset(source, prefetch=2): with embodied.timer.section('learner_batch_next'): batch = next(dataset_train) with embodied.timer.section('learner_train_step'): - outs, state, mets = agent.train(batch, state) + outs, carry, mets = agent.train(batch, carry) if 'replay' in outs: with embodied.timer.section('learner_replay_update'): updater.update(outs['replay']) @@ -173,9 +183,9 @@ def parallel_dataset(source, prefetch=2): if should_eval(): with embodied.timer.section('learner_eval'): if received['report'] > 0: - logger.add(prefix(agent.report(next(dataset_report)), 'report')) + logger.add(prefix(evaluate(dataset_report), 'report')) if received['eval'] > 0: - logger.add(prefix(agent.report(next(dataset_eval)), 'eval')) + logger.add(prefix(evaluate(dataset_eval), 'eval')) if should_log(): with embodied.timer.section('learner_metrics'): @@ -355,7 +365,7 @@ def parallel_env(make_env, envid, args, logging=False, is_eval=False): _print = lambda x: embodied.print(f'[{name}] {x}', flush=True) should_log = embodied.when.Clock(args.log_every) - if logging: + if logging and envid == 0: logger = embodied.distr.Client( args.logger_addr, f'{name}Logger', args.ipv6, maxinflight=1, connect=True) diff --git a/embodied/run/train.py b/embodied/run/train.py index d38856b3..0f61ef97 100644 --- a/embodied/run/train.py +++ b/embodied/run/train.py @@ -76,6 +76,7 @@ def log_step(tran, worker): dataset_report = iter(agent.dataset(bind( replay.dataset, args.batch_size, args.batch_length_eval))) carry = [agent.init_train(args.batch_size)] + carry_report = agent.init_report(args.batch_size) def train_step(tran, worker): if len(replay) < args.batch_size or step < args.train_fill: @@ -108,7 +109,8 @@ def train_step(tran, worker): driver(policy, steps=10) if 
should_eval(step) and len(replay): - logger.add(agent.report(next(dataset_report)), prefix='report') + mets, _ = agent.report(next(dataset_report), carry_report) + logger.add(mets, prefix='report') if should_log(step): logger.add(agg.result()) diff --git a/embodied/run/train_eval.py b/embodied/run/train_eval.py index a53744cd..133dc8d5 100644 --- a/embodied/run/train_eval.py +++ b/embodied/run/train_eval.py @@ -91,6 +91,7 @@ def log_step(tran, worker, mode): dataset_eval = agent.dataset( bind(eval_replay.dataset, args.batch_size, args.batch_length_eval)) carry = [agent.init_train(args.batch_size)] + carry_report = agent.init_report(args.batch_size) def train_step(tran, worker): if len(train_replay) < args.batch_size or step < args.train_fill: @@ -128,9 +129,11 @@ def train_step(tran, worker): eval_driver(eval_policy, episodes=args.eval_eps) logger.add(eval_epstats.result(), prefix='epstats') if len(train_replay): - logger.add(agent.report(next(dataset_report)), prefix='report') + mets, _ = agent.report(next(dataset_report), carry_report) + logger.add(mets, prefix='report') if len(eval_replay): - logger.add(agent.report(next(dataset_eval)), prefix='eval') + mets, _ = agent.report(next(dataset_eval), carry_report) + logger.add(mets, prefix='eval') train_driver(train_policy, steps=10) diff --git a/embodied/run/train_holdout.py b/embodied/run/train_holdout.py index 88b74bc4..b7d2a6a8 100644 --- a/embodied/run/train_holdout.py +++ b/embodied/run/train_holdout.py @@ -81,6 +81,7 @@ def log_step(tran, worker): bind(eval_replay.dataset, args.batch_size, args.batch_length_eval)) carry = [agent.init_train(args.batch_size)] + carry_report = agent.init_report(args.batch_size) def train_step(tran, worker): if len(train_replay) < args.batch_size or step < args.train_fill: @@ -117,9 +118,11 @@ def train_step(tran, worker): logger.add(agg.result()) logger.add(epstats.result(), prefix='epstats') if len(train_replay): - logger.add(agent.report(next(dataset_report)), prefix='report') 
+ mets, _ = agent.report(next(dataset_report), carry_report) + logger.add(mets, prefix='report') if len(eval_replay): - logger.add(agent.report(next(dataset_eval)), prefix='eval') + mets, _ = agent.report(next(dataset_eval), carry_report) + logger.add(mets, prefix='eval') logger.add(embodied.timer.stats(), prefix='timer') logger.add(train_replay.stats(), prefix='replay') logger.add(usage.stats(), prefix='usage') diff --git a/embodied/tests/distr/test_process.py b/embodied/tests/distr/test_process.py index 34cb7304..89bd1624 100644 --- a/embodied/tests/distr/test_process.py +++ b/embodied/tests/distr/test_process.py @@ -49,7 +49,7 @@ def fn1234(q): q = mp.get_context().SimpleQueue() worker = embodied.distr.Process(fn1234, q, start=True) q.get() - time.sleep(0.1) + time.sleep(0.5) assert not worker.running assert worker.exitcode == 1 with pytest.raises(KeyError) as info: diff --git a/embodied/tests/run/test_parallel.py b/embodied/tests/run/test_parallel.py index 7a0735e0..2db0677c 100644 --- a/embodied/tests/run/test_parallel.py +++ b/embodied/tests/run/test_parallel.py @@ -28,7 +28,7 @@ def test_run_loop(self, tmpdir, train_ratio): for key in ('actor_addr', 'replay_addr', 'logger_addr'): ports.append(args[key].replace('-', ':').split(':')[-1]) - embodied.run.parallel( + embodied.run.parallel.combined( bind(self._make_agent, addr), bind(self._make_replay, args), self._make_env, self._make_logger, args) @@ -44,17 +44,17 @@ def test_run_loop(self, tmpdir, train_ratio): assert stats['reports'] >= 1 assert stats['saves'] >= 2 assert stats['loads'] == 0 - for port in ports: - assert embodied.distr.port_free(port) + # for port in ports: + # assert embodied.distr.port_free(port) - embodied.run.parallel( + embodied.run.parallel.combined( bind(self._make_agent, addr), bind(self._make_replay, args), self._make_env, self._make_logger, args) stats = received[0] assert stats['loads'] == 1 - for port in ports: - assert embodied.distr.port_free(port) + # for port in ports: + # assert
embodied.distr.port_free(port) def _make_agent(self, queue): env = self._make_env(0) @@ -87,10 +87,13 @@ def _make_args(self, logdir, train_ratio): duration=10, log_every=3, save_every=5, + eval_every=5, train_ratio=float(train_ratio), train_fill=100, batch_size=8, batch_length=16, + batch_length_eval=8, + replay_context=0, expl_until=0, from_checkpoint='', usage=dict(psutil=True, nvsmi=False), @@ -112,4 +115,6 @@ def _make_args(self, logdir, train_ratio): env_replica=-1, ipv6=False, timer=True, + agent_process=False, + remote_replay=False, ) diff --git a/embodied/tests/run/test_train.py b/embodied/tests/run/test_train.py index 717ba7e5..ee9b7349 100644 --- a/embodied/tests/run/test_train.py +++ b/embodied/tests/run/test_train.py @@ -61,13 +61,16 @@ def _make_args(self, logdir): return embodied.Config( logdir=str(logdir), num_envs=4, - steps=2e4, + steps=5e4, log_every=3, save_every=5, + eval_every=5, train_ratio=32.0, train_fill=100, batch_size=8, batch_length=16, + batch_length_eval=8, + replay_context=0, expl_until=0, from_checkpoint='', usage=dict(psutil=True, nvsmi=False), diff --git a/embodied/tests/run/utils.py b/embodied/tests/run/utils.py index 34d3f8ea..a19a3505 100644 --- a/embodied/tests/run/utils.py +++ b/embodied/tests/run/utils.py @@ -37,6 +37,9 @@ def init_policy(self, batch_size): def init_train(self, batch_size): return (np.zeros(batch_size),) + def init_report(self, batch_size): + return () + def policy(self, obs, carry, mode='train'): B = len(obs['is_first']) self._stats['env_steps'] += B @@ -56,7 +59,7 @@ def policy(self, obs, carry, mode='train'): act = { k: np.stack([v.sample() for _ in range(B)]) for k, v in self.act_space.items() if k != 'reset'} - return act, (carry,) + return act, {}, (carry,) def train(self, data, carry): B, T = data['step'].shape @@ -75,7 +78,7 @@ def train(self, data, carry): metrics = {} return outs, (carry,), metrics - def report(self, data): + def report(self, data, carry): self._stats['reports'] += 1 return { 
'scalar': np.float32(0), @@ -83,7 +86,7 @@ def report(self, data): 'image1': np.zeros((64, 64, 1)), 'image3': np.zeros((64, 64, 3)), 'video': np.zeros((10, 64, 64, 3)), - } + }, carry def dataset(self, generator): return generator() diff --git a/embodied/tests/test_driver.py b/embodied/tests/test_driver.py index 53c94c92..26dbbf60 100644 --- a/embodied/tests/test_driver.py +++ b/embodied/tests/test_driver.py @@ -52,7 +52,7 @@ def test_env_reset(self): seq = [] driver.on_step(lambda tran, _: seq.append(tran)) action = np.array([1]) - driver(lambda obs, state: ({'action': action}, state), episodes=2) + driver(lambda obs, state: ({'action': action}, {}, state), episodes=2) assert len(seq) == 12 seq = {k: np.array([seq[i][k] for i in range(len(seq))]) for k in seq[0]} assert (seq['is_first'] == [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]).all() @@ -69,8 +69,8 @@ def test_agent_inputs(self): def policy(obs, state=None, mode='train'): inputs.append(obs) states.append(state) - act, _ = agent.policy(obs, state, mode) - return act, 'state' + act, _, _ = agent.policy(obs, state, mode) + return act, {}, 'state' seq = [] driver.on_step(lambda tran, _: seq.append(tran)) driver(policy, episodes=2) diff --git a/embodied/tests/test_replay.py b/embodied/tests/test_replay.py index 220a4a40..318dc83c 100644 --- a/embodied/tests/test_replay.py +++ b/embodied/tests/test_replay.py @@ -25,6 +25,11 @@ ] +def unbatched(dataset): + for batch in dataset: + yield {k: v[0] for k, v in batch.items()} + + @pytest.mark.filterwarnings('ignore:.*Pillow.*') @pytest.mark.filterwarnings('ignore:.*the imp module.*') @pytest.mark.filterwarnings('ignore:.*distutils.*') @@ -35,10 +40,9 @@ def test_multiple_keys(self, Replay): replay = Replay(length=5, capacity=10) for step in range(30): replay.add({'image': np.zeros((64, 64, 3)), 'action': np.zeros(12)}) - seq = next(iter(replay.dataset())) - # assert set(seq.keys()) == {'id', 'image', 'action'} - assert set(seq.keys()) == {'image', 'action'} - # assert 
seq['id'].shape == (5, 16) + seq = next(unbatched(replay.dataset(1))) + assert set(seq.keys()) == {'stepid', 'image', 'action'} + assert seq['stepid'].shape == (5, 20) assert seq['image'].shape == (5, 64, 64, 3) assert seq['action'].shape == (5, 12) @@ -65,7 +69,7 @@ def test_sample_sequences( for step in range(30): for worker in range(workers): replay.add({'step': step, 'worker': worker}, worker) - dataset = iter(replay.dataset()) + dataset = unbatched(replay.dataset(1)) for _ in range(10): seq = next(dataset) assert (seq['step'] - seq['step'][0] == np.arange(length)).all() @@ -78,7 +82,7 @@ def test_sample_single(self, Replay, length, capacity): replay = Replay(length, capacity) for step in range(length): replay.add({'step': step}) - dataset = iter(replay.dataset()) + dataset = unbatched(replay.dataset(1)) for _ in range(10): seq = next(dataset) assert (seq['step'] == np.arange(length)).all() @@ -90,7 +94,7 @@ def test_sample_uniform(self, Replay): replay.add({'step': step}) assert len(replay) == 3 histogram = collections.defaultdict(int) - dataset = iter(replay.dataset()) + dataset = unbatched(replay.dataset(1)) for _ in range(100): seq = next(dataset) histogram[seq['step'][0]] += 1 @@ -107,7 +111,7 @@ def test_workers_simple(self, Replay): replay.add({'step': 1}, worker=1) replay.add({'step': 2}, worker=0) replay.add({'step': 3}, worker=1) - dataset = iter(replay.dataset()) + dataset = unbatched(replay.dataset(1)) for _ in range(10): seq = next(dataset) assert tuple(seq['step']) in ((0, 2), (1, 3)) @@ -125,7 +129,7 @@ def test_workers_random(self, Replay, length=4, capacity=30): except StopIteration: pass histogram = collections.defaultdict(int) - dataset = iter(replay.dataset()) + dataset = unbatched(replay.dataset(1)) for _ in range(10): seq = next(dataset) assert (seq['step'] - seq['step'][0] == np.arange(length)).all() @@ -155,16 +159,18 @@ def test_worker_delay(self, Replay, length, workers, capacity): [(1, 1, 128), (3, 10, 128), (5, 100, 128), (5, 25, 
2)]) def test_restore_exact(self, tmpdir, Replay, length, capacity, chunksize): embodied.uuid.reset(debug=True) - replay = Replay(length, capacity, directory=tmpdir, chunksize=chunksize) + replay = Replay( + length, capacity, directory=tmpdir, chunksize=chunksize, + save_wait=True) for step in range(30): replay.add({'step': step}) num_items = np.clip(30 - length + 1, 0, capacity) assert len(replay) == num_items - replay.save(wait=True) + data = replay.save() replay = Replay(length, capacity, directory=tmpdir) - replay.load() + replay.load(data) assert len(replay) == num_items - dataset = iter(replay.dataset()) + dataset = unbatched(replay.dataset(1)) for _ in range(len(replay)): assert len(next(dataset)['step']) == length @@ -174,16 +180,18 @@ def test_restore_exact(self, tmpdir, Replay, length, capacity, chunksize): [(1, 1, 128), (3, 10, 128), (5, 100, 128), (5, 25, 2)]) def test_restore_noclear(self, tmpdir, Replay, length, capacity, chunksize): embodied.uuid.reset(debug=True) - replay = Replay(length, capacity, directory=tmpdir, chunksize=chunksize) + replay = Replay( + length, capacity, directory=tmpdir, chunksize=chunksize, + save_wait=True) for _ in range(30): replay.add({'foo': 13}) num_items = np.clip(30 - length + 1, 0, capacity) assert len(replay) == num_items - replay.save(wait=True) + data = replay.save() for _ in range(30): replay.add({'foo': 42}) - replay.load() - dataset = iter(replay.dataset()) + replay.load(data) + dataset = unbatched(replay.dataset(1)) if capacity < num_items: for _ in range(len(replay)): assert next(dataset)['foo'] == 13 @@ -193,17 +201,18 @@ def test_restore_noclear(self, tmpdir, Replay, length, capacity, chunksize): @pytest.mark.parametrize('length,capacity', [(1, 1), (3, 10), (5, 100)]) def test_restore_workers(self, tmpdir, Replay, workers, length, capacity): capacity *= workers - replay = Replay(length, capacity, directory=tmpdir) + replay = Replay( + length, capacity, directory=tmpdir, save_wait=True) for step in range(50): 
for worker in range(workers): replay.add({'step': step}, worker) num_items = np.clip((50 - length + 1) * workers, 0, capacity) assert len(replay) == num_items - replay.save(wait=True) + data = replay.save() replay = Replay(length, capacity, directory=tmpdir) - replay.load() + replay.load(data) assert len(replay) == num_items - dataset = iter(replay.dataset()) + dataset = unbatched(replay.dataset(1)) for _ in range(len(replay)): assert len(next(dataset)['step']) == length @@ -214,22 +223,29 @@ def test_restore_chunks_exact( self, tmpdir, Replay, length, capacity, chunksize): embodied.uuid.reset(debug=True) assert len(list(embodied.Path(tmpdir).glob('*.npz'))) == 0 - replay = Replay(length, capacity, directory=tmpdir, chunksize=chunksize) + replay = Replay( + length, capacity, directory=tmpdir, chunksize=chunksize, + save_wait=True) for step in range(30): replay.add({'step': step}) num_items = np.clip(30 - length + 1, 0, capacity) assert len(replay) == num_items - replay.save(wait=True) + data = replay.save() filenames = list(embodied.Path(tmpdir).glob('*.npz')) lengths = [int(x.stem.split('-')[3]) for x in filenames] - assert len(filenames) == (int(np.ceil(30 / chunksize))) - assert sum(lengths) == 30 + stored_steps = min(capacity + length - 1, 30) + total_chunks = int(np.ceil(30 / chunksize)) + pruned_chunks = int(np.floor((30 - stored_steps) / chunksize)) + assert len(filenames) == total_chunks - pruned_chunks + last_chunk_empty = total_chunks * chunksize - 30 + saved_steps = (total_chunks - pruned_chunks) * chunksize - last_chunk_empty + assert sum(lengths) == saved_steps assert all(1 <= x <= chunksize for x in lengths) replay = Replay(length, capacity, directory=tmpdir, chunksize=chunksize) - replay.load() + replay.load(data) assert sorted(embodied.Path(tmpdir).glob('*.npz')) == sorted(filenames) assert len(replay) == num_items - dataset = iter(replay.dataset()) + dataset = unbatched(replay.dataset(1)) for _ in range(len(replay)): assert 
len(next(dataset)['step']) == length @@ -240,20 +256,28 @@ def test_restore_chunks_exact( def test_restore_chunks_workers( self, tmpdir, Replay, workers, length, capacity, chunksize): capacity *= workers - replay = Replay(length, capacity, directory=tmpdir, chunksize=chunksize) + replay = Replay( + length, capacity, directory=tmpdir, chunksize=chunksize, + save_wait=True) for step in range(50): for worker in range(workers): replay.add({'step': step}, worker) num_items = np.clip((50 - length + 1) * workers, 0, capacity) assert len(replay) == num_items - replay.save(wait=True) + data = replay.save() filenames = list(embodied.Path(tmpdir).glob('*.npz')) lengths = [int(x.stem.split('-')[3]) for x in filenames] - assert sum(lengths) == 50 * workers + stored_steps = min(capacity // workers + length - 1, 50) + total_chunks = int(np.ceil(50 / chunksize)) + pruned_chunks = int(np.floor((50 - stored_steps) / chunksize)) + assert len(filenames) == (total_chunks - pruned_chunks) * workers + last_chunk_empty = total_chunks * chunksize - 50 + saved_steps = (total_chunks - pruned_chunks) * chunksize - last_chunk_empty + assert sum(lengths) == saved_steps * workers replay = Replay(length, capacity, directory=tmpdir, chunksize=chunksize) - replay.load() + replay.load(data) assert len(replay) == num_items - dataset = iter(replay.dataset()) + dataset = unbatched(replay.dataset(1)) for _ in range(len(replay)): assert len(next(dataset)['step']) == length @@ -263,17 +287,19 @@ def test_restore_chunks_workers( [(1, 1, 128), (3, 10, 128), (5, 100, 128), (5, 25, 2)]) def test_restore_insert(self, tmpdir, Replay, length, capacity, chunksize): embodied.uuid.reset(debug=True) - replay = Replay(length, capacity, directory=tmpdir, chunksize=chunksize) + replay = Replay( + length, capacity, directory=tmpdir, chunksize=chunksize, + save_wait=True) inserts = int(1.5 * chunksize) for step in range(inserts): replay.add({'step': step}) num_items = np.clip(inserts - length + 1, 0, capacity) assert 
len(replay) == num_items - replay.save(wait=True) + data = replay.save() replay = Replay(length, capacity, directory=tmpdir) - replay.load() + replay.load(data) assert len(replay) == num_items - dataset = iter(replay.dataset()) + dataset = unbatched(replay.dataset(1)) for _ in range(len(replay)): assert len(next(dataset)['step']) == length for step in range(inserts): @@ -286,7 +312,9 @@ def test_threading( self, tmpdir, Replay, length=5, capacity=128, chunksize=32, adders=8, samplers=4): embodied.uuid.reset(debug=True) - replay = Replay(length, capacity, directory=tmpdir, chunksize=chunksize) + replay = Replay( + length, capacity, directory=tmpdir, chunksize=chunksize, + save_wait=True) running = [True] def adder(): @@ -298,8 +326,7 @@ def adder(): time.sleep(0.001) def sampler(): - ident = threading.get_ident() - dataset = iter(replay.dataset()) + dataset = unbatched(replay.dataset(1)) while running[0]: seq = next(dataset) assert (seq['step'] - seq['step'][0] == np.arange(length)).all() @@ -321,109 +348,14 @@ def sampler(): assert stats['samples'] > 0 print('SAVING') - replay.save(wait=True) + data = replay.save() time.sleep(0.1) print('LOADING') - # replay = Replay(length, capacity, directory=tmpdir, chunksize=chunksize) - # replay.clear() - replay.load() + replay.load(data) finally: running[0] = False [worker.join() for worker in workers] assert len(replay) == capacity - - - # @pytest.mark.parametrize('Replay', REPLAYS_UNLIMITED) - # @pytest.mark.parametrize( - # 'length,capacity,chunksize,workers', - # [(3, 100, 16, 4)]) - # def test_restore_capacity( - # self, tmpdir, Replay, length, capacity, chunksize, workers): - # embodied.uuid.reset(debug=True) - # replay = Replay(length, capacity, directory=tmpdir, chunksize=chunksize) - - # for step in range(500): - # for worker in range(workers): - # replay.add({'step': step}, worker=worker) - - # num_items = np.clip(workers * (500 - length + 1), 0, capacity) - # assert len(replay) == num_items - # 
replay.save(wait=True) - # replay = Replay(length, capacity, directory=tmpdir) - - # rng = np.random.default_rng(seed=0) - # filenames = sorted(embodied.Path(tmpdir).glob('*.npz')) - # # for filename in rng.choice(filenames, min(100, len(filenames)), replace=False): - # for filename in filenames[:120]: - # filename.remove() - - # # for step in range(10): - # # for worker in range(workers): - # # replay.add({'step': step}, worker=worker) - - # replay.load() - # assert len(replay) == num_items - - # # dataset = iter(replay.dataset()) - # # for _ in range(len(replay)): - # # assert len(next(dataset)['step']) == length - # # for step in range(inserts): - # # replay.add({'step': step}) - # # num_items = np.clip(2 * (inserts - length + 1), 0, capacity) - # # assert len(replay) == num_items - - - # @pytest.mark.parametrize('Replay', REPLAYS_QUEUES) - # @pytest.mark.parametrize( - # 'length,capacity,overlap', - # [(1, 1, 0), (5, 10, 3), (10, 5, 2)]) - # def test_queue_single(self, Replay, length, capacity, overlap): - # replay = Replay(length, capacity, overlap=overlap) - # for step in range(length): - # replay.add({'step': step}) - # dataset = iter(replay.dataset()) - # seq = next(dataset) - # assert (seq['step'] == np.arange(length)).all() - - # @pytest.mark.parametrize('Replay', REPLAYS_QUEUES) - # @pytest.mark.parametrize( - # 'length,capacity,overlap', - # [(1, 5, 0), (2, 5, 1), (5, 10, 3), (10, 5, 0), (10, 5, 2)]) - # def test_queue_order(self, Replay, length, capacity, overlap): - # assert overlap < length - # assert 5 <= capacity - # replay = Replay(length, capacity, overlap=overlap) - # inserts = length + 4 * (length - overlap) - # for step in range(inserts): - # replay.add({'step': step}) - # dataset = iter(replay.dataset()) - # for index in range(len(replay)): - # seq = next(dataset) - # start = index * (length - overlap) - # assert seq['step'][0] == start - # assert (seq['step'] - start == np.arange(length)).all() - - # @pytest.mark.parametrize('Replay', 
REPLAYS_QUEUES) - # @pytest.mark.parametrize( - # 'length,capacity,overlap,workers', - # [(1, 10, 0, 2), (2, 10, 1, 2), (5, 30, 3, 4)]) - # def test_queue_workers(self, Replay, length, capacity, overlap, workers): - # assert overlap < length - # assert 5 * workers <= capacity - # replay = Replay(length, capacity, overlap=overlap) - # inserts = length + 4 * (length - overlap) - # for step in range(inserts): - # for worker in range(workers): - # replay.add({'step': step, 'worker': worker}, worker) - # dataset = iter(replay.dataset()) - # assert len(replay) == 5 * workers - # for index in range(5): - # for worker in range(workers): - # seq = next(dataset) - # start = index * (length - overlap) - # assert seq['step'][0] == start - # assert (seq['worker'] == worker).all() - # assert (seq['step'] - start == np.arange(length)).all()