Skip to content

Commit

Permalink
fix: timestep calculation with accumulation
Browse files Browse the repository at this point in the history
  • Loading branch information
Louay-Ben-nessir committed Oct 17, 2024
1 parent aa49c6f commit c252ffe
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
2 changes: 1 addition & 1 deletion mava/systems/ppo/sebulba/ff_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ def run_experiment(_config: DictConfig) -> float:
check_sebulba_config(config)

steps_per_rollout = (
config.system.rollout_length * config.arch.num_envs * config.system.num_updates_per_eval
config.system.rollout_length * config.arch.num_envs * config.system.num_updates_per_eval * config.arch.n_learner_accumulate
)

# Logger setup
Expand Down
4 changes: 4 additions & 0 deletions mava/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@ def check_total_timesteps(config: DictConfig) -> DictConfig:
if config.arch.architecture_name == "anakin":
n_devices = len(jax.devices())
update_batch_size = config.system.update_batch_size
n_accumulate = 1 # We don't accumulate envs in anakin
else:
n_devices = 1 # We only use a single device's output when updating.
update_batch_size = 1
n_accumulate = config.arch.n_learner_accumulate

if config.system.total_timesteps is None:
config.system.num_updates = int(config.system.num_updates)
Expand All @@ -58,6 +60,7 @@ def check_total_timesteps(config: DictConfig) -> DictConfig:
* config.system.rollout_length
* update_batch_size
* config.arch.num_envs
* n_accumulate
)
else:
config.system.total_timesteps = int(config.system.total_timesteps)
Expand All @@ -67,6 +70,7 @@ def check_total_timesteps(config: DictConfig) -> DictConfig:
// update_batch_size
// config.arch.num_envs
// n_devices
// n_accumulate
)
print(
f"{Fore.RED}{Style.BRIGHT} Changing the number of updates "
Expand Down

0 comments on commit c252ffe

Please sign in to comment.