From ece8b4170094fd116f64324d20be53cada131bb2 Mon Sep 17 00:00:00 2001
From: KuoHaoZeng
Date: Fri, 6 Sep 2024 09:37:10 -0700
Subject: [PATCH] add replay buffer

---
 allenact/algorithms/onpolicy_sync/engine.py | 32 ++++++++++++++++++--------------
 allenact/algorithms/onpolicy_sync/runner.py |  4 ++++
 allenact/main.py                            | 20 ++++++++++++++++++++
 3 files changed, 42 insertions(+), 14 deletions(-)

diff --git a/allenact/algorithms/onpolicy_sync/engine.py b/allenact/algorithms/onpolicy_sync/engine.py
index f3c71e3c..ce2eea4d 100644
--- a/allenact/algorithms/onpolicy_sync/engine.py
+++ b/allenact/algorithms/onpolicy_sync/engine.py
@@ -1200,8 +1200,8 @@ def __init__(
         save_ckpt_after_every_pipeline_stage: bool = True,
         first_local_worker_id: int = 0,
         save_ckpt_at_every_host: bool = False,
-        offpolicy_batch_size: int = 32,
-        replay_buffer_max_size: int = 640,
+        offpolicy_batch_size: int = 0,
+        offpolicy_max_batch_size: int = 640,
         **kwargs,
     ):
         kwargs["mode"] = TRAIN_MODE_STR
@@ -1231,14 +1231,16 @@
         self.training_pipeline: TrainingPipeline = config.training_pipeline()

         # [OFFP]
-        self.replay_buffer = ReplayBuffer(
-            storage=LazyMemmapStorage(
-                max_size=replay_buffer_max_size,
-                device=torch.device("cpu"),
-                scratch_dir="/tmp/replay_buffer/",
-            ),
-            batch_size=offpolicy_batch_size,
-        )
+        self.replay_buffer = None
+        if offpolicy_batch_size > 0:
+            self.replay_buffer = ReplayBuffer(
+                storage=LazyMemmapStorage(
+                    max_size=offpolicy_max_batch_size,
+                    device=torch.device("cpu"),
+                    scratch_dir="/tmp/replay_buffer/",
+                ),
+                batch_size=offpolicy_batch_size,
+            )

         if self.num_workers != 1:
             # Ensure that we're only using early stopping criterions in the non-distributed setting.
@@ -1851,9 +1853,10 @@ def run_pipeline(self, valid_on_initial_weights: bool = False):
             for storage in self.training_pipeline.current_stage_storage.values():
                 storage.before_updates(**before_update_info)

-            adapted_storage = StorageAdapter(storage, torch.device("cpu"))
-            tensordict = adapted_storage.to_tensordict(batch_size=[storage.rewards.shape[1]])
-            self.replay_buffer.extend(tensordict)
+            if self.replay_buffer is not None:
+                adapted_storage = StorageAdapter(storage, torch.device("cpu"))
+                tensordict = adapted_storage.to_tensordict(batch_size=[storage.rewards.shape[1]])
+                self.replay_buffer.extend(tensordict)

             for sc in self.training_pipeline.current_stage.stage_components:
                 component_storage = uuid_to_storage[sc.storage_uuid]
@@ -1861,7 +1864,8 @@ def run_pipeline(self, valid_on_initial_weights: bool = False):
                 self.compute_losses_track_them_and_backprop(
                     stage=self.training_pipeline.current_stage,
                     stage_component=sc,
-                    storage=component_storage,
+                    storage=component_storage if self.replay_buffer is None else None,
+                    replay_buffer=self.replay_buffer,
                 )

             for storage in self.training_pipeline.current_stage_storage.values():
diff --git a/allenact/algorithms/onpolicy_sync/runner.py b/allenact/algorithms/onpolicy_sync/runner.py
index 84f4f2c5..2db698d5 100644
--- a/allenact/algorithms/onpolicy_sync/runner.py
+++ b/allenact/algorithms/onpolicy_sync/runner.py
@@ -502,6 +502,8 @@ def start_train(
         valid_on_initial_weights: bool = False,
         try_restart_after_task_error: bool = False,
         save_ckpt_at_every_host: bool = False,
+        offpolicy_batch_size: Optional[int] = 0,
+        offpolicy_max_batch_size: Optional[int] = 640,
     ):
         self._initialize_start_train_or_start_test()

@@ -574,6 +576,8 @@
             valid_on_initial_weights=valid_on_initial_weights,
             try_restart_after_task_error=try_restart_after_task_error,
             save_ckpt_at_every_host=save_ckpt_at_every_host,
+            offpolicy_batch_size=offpolicy_batch_size,
+            offpolicy_max_batch_size=offpolicy_max_batch_size,
         )
         train: BaseProcess = self.mp_ctx.Process(
             target=self.train_loop,
diff --git a/allenact/main.py b/allenact/main.py
index 827d5205..3263c994 100755
--- a/allenact/main.py
+++ b/allenact/main.py
@@ -284,6 +284,24 @@ def get_argument_parser():
     )
     parser.set_defaults(save_ckpt_at_every_host=False)

+    parser.add_argument(
+        "--offpolicy_batch_size",
+        dest="offpolicy_batch_size",
+        required=False,
+        type=int,
+        default=0,
+        help="Batch size for off-policy training (default: 0, i.e. on-policy training).",
+    )
+
+    parser.add_argument(
+        "--offpolicy_max_batch_size",
+        dest="offpolicy_max_batch_size",
+        required=False,
+        type=int,
+        default=640,
+        help="Maximum number of entries kept in the replay buffer used for off-policy training.",
+    )
+
     parser.add_argument(
         "--callbacks",
         dest="callbacks",
@@ -495,6 +513,8 @@ def main():
             valid_on_initial_weights=args.valid_on_initial_weights,
             try_restart_after_task_error=args.enable_crash_recovery,
             save_ckpt_at_every_host=args.save_ckpt_at_every_host,
+            offpolicy_batch_size=args.offpolicy_batch_size,
+            offpolicy_max_batch_size=args.offpolicy_max_batch_size,
         )
     else:
         OnPolicyRunner(
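
Not part of the patch above: a minimal, self-contained sketch of the torchrl replay-buffer pattern the change relies on (a ReplayBuffer over a LazyMemmapStorage, filled with TensorDicts and sampled for off-policy updates). It assumes the torchrl and tensordict packages are installed; the field names and tensor shapes below are illustrative and are not the actual AllenAct rollout layout.

import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, ReplayBuffer

# Disk-backed (memory-mapped) CPU storage, configured as in the patch.
replay_buffer = ReplayBuffer(
    storage=LazyMemmapStorage(
        max_size=640,                       # cf. --offpolicy_max_batch_size
        device=torch.device("cpu"),
        scratch_dir="/tmp/replay_buffer/",  # memmap files are written here
    ),
    batch_size=32,                          # cf. --offpolicy_batch_size
)

# A rollout flattened into a TensorDict; the patch builds this from the
# on-policy storage via StorageAdapter. Keys and shapes here are made up.
rollout = TensorDict(
    {"observations": torch.randn(64, 4), "rewards": torch.randn(64, 1)},
    batch_size=[64],
)
replay_buffer.extend(rollout)    # appends one entry per element of the leading dim
batch = replay_buffer.sample()   # draws `batch_size` entries for an off-policy update

Because LazyMemmapStorage memory-maps its tensors under scratch_dir, the replay data lives in disk-backed CPU memory rather than GPU memory. With the flags added in main.py, off-policy sampling is enabled by passing a positive batch size, e.g. --offpolicy_batch_size 32 --offpolicy_max_batch_size 640; the default of 0 keeps training purely on-policy.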