diff --git a/allenact/algorithms/onpolicy_sync/engine.py b/allenact/algorithms/onpolicy_sync/engine.py index 1b77eae4..3f98a913 100644 --- a/allenact/algorithms/onpolicy_sync/engine.py +++ b/allenact/algorithms/onpolicy_sync/engine.py @@ -1064,7 +1064,11 @@ def single_batch_generator(streaming_storage: StreamingStorageMixin): if training: aggregate_bsize = self.distributed_weighted_sum(bsize, 1) to_track["global_batch_size"] = aggregate_bsize - to_track["lr"] = self.optimizer.param_groups[0]["lr"] + if len(self.optimizer.param_groups) >= 2: + for i, param_group in enumerate(self.optimizer.param_groups): + to_track[f"lr_group_{i}"] = param_group["lr"] + else: + to_track["lr"] = self.optimizer.param_groups[0]["lr"] if training_settings.num_mini_batch is not None: to_track["rollout_num_mini_batch"] = (