diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 95566efea..639ff1fe0 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -559,7 +559,7 @@ def run_experiment(_config: DictConfig) -> float: check_sebulba_config(config) steps_per_rollout = ( - config.system.rollout_length * config.arch.num_envs * config.system.num_updates_per_eval + config.system.rollout_length * config.arch.num_envs * config.system.num_updates_per_eval * config.arch.n_learner_accumulate ) # Logger setup diff --git a/mava/utils/config.py b/mava/utils/config.py index 23484311b..34a35f091 100644 --- a/mava/utils/config.py +++ b/mava/utils/config.py @@ -46,9 +46,11 @@ def check_total_timesteps(config: DictConfig) -> DictConfig: if config.arch.architecture_name == "anakin": n_devices = len(jax.devices()) update_batch_size = config.system.update_batch_size + n_accumulate = 1 # We don't accumulate envs in anakin else: n_devices = 1 # We only use a single device's output when updating. update_batch_size = 1 + n_accumulate = config.arch.n_learner_accumulate if config.system.total_timesteps is None: config.system.num_updates = int(config.system.num_updates) @@ -58,6 +60,7 @@ def check_total_timesteps(config: DictConfig) -> DictConfig: * config.system.rollout_length * update_batch_size * config.arch.num_envs + * n_accumulate ) else: config.system.total_timesteps = int(config.system.total_timesteps) @@ -67,6 +70,7 @@ def check_total_timesteps(config: DictConfig) -> DictConfig: // update_batch_size // config.arch.num_envs // n_devices + // n_accumulate ) print( f"{Fore.RED}{Style.BRIGHT} Changing the number of updates "