diff --git a/safepo/multi_agent/macpo.py b/safepo/multi_agent/macpo.py
index 27a034c..60a8fca 100644
--- a/safepo/multi_agent/macpo.py
+++ b/safepo/multi_agent/macpo.py
@@ -254,11 +254,11 @@ def trpo_update(self, sample):
 
         g_step_dir = self.conjugate_gradient(
             self.policy.actor, obs_batch, rnn_states_batch, actions_batch, masks_batch,\
-            available_actions_batch, active_masks_batch, reward_loss_grad.data, nsteps=10
+            available_actions_batch, active_masks_batch, reward_loss_grad.data, nsteps=self.config["conjugate_gradient_iters"]
         )
         b_step_dir = self.conjugate_gradient(
             self.policy.actor, obs_batch, rnn_states_batch, actions_batch, masks_batch,\
-            available_actions_batch, active_masks_batch, B_cost_loss_grad.data, nsteps=10
+            available_actions_batch, active_masks_batch, B_cost_loss_grad.data, nsteps=self.config["conjugate_gradient_iters"]
         )
 
         q_coef = (reward_loss_grad * g_step_dir).sum(0, keepdim=True)
diff --git a/safepo/multi_agent/marl_cfg/macpo/config.yaml b/safepo/multi_agent/marl_cfg/macpo/config.yaml
index cf5ab01..007bf0b 100644
--- a/safepo/multi_agent/marl_cfg/macpo/config.yaml
+++ b/safepo/multi_agent/marl_cfg/macpo/config.yaml
@@ -34,6 +34,7 @@ use_proper_time_limits: False
 
 target_kl: 0.016
 searching_steps: 10
+conjugate_gradient_iters: 10
 accept_ratio: 0.5
 clip_param: 0.2
 learning_iters: 5
@@ -78,5 +79,5 @@ mamujoco:
   gamma: 0.99
   safety_gamma: 0.2
   target_kl: 0.01
-  learning_iters: 15 # Conjugate Gradient Iterations
+  learning_iters: 15 # Number of SGD Iterations
   entropy_coef: 0.01
\ No newline at end of file
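
The change replaces the hard-coded `nsteps=10` with a value read from the algorithm config, so the number of conjugate-gradient iterations used to approximate the natural-gradient directions is tunable alongside the other MACPO hyperparameters. Below is a minimal, standalone sketch (not the SafePO implementation) of how such an `nsteps` bound is typically used in the conjugate-gradient solve; the toy matrix `H`, the `hvp` callable, and the local `config` dict are illustrative stand-ins for the Fisher-vector product and `self.config` in `macpo.py`.

    import torch

    config = {"conjugate_gradient_iters": 10}  # mirrors the new YAML entry

    def conjugate_gradient(hvp, b, nsteps, residual_tol=1e-10):
        """Approximately solve H x = b using at most `nsteps` CG iterations."""
        x = torch.zeros_like(b)
        r = b.clone()            # residual (x starts at zero, so r = b)
        p = r.clone()            # search direction
        rdotr = torch.dot(r, r)
        for _ in range(nsteps):
            hp = hvp(p)
            alpha = rdotr / torch.dot(p, hp)
            x += alpha * p
            r -= alpha * hp
            new_rdotr = torch.dot(r, r)
            if new_rdotr < residual_tol:
                break
            p = r + (new_rdotr / rdotr) * p
            rdotr = new_rdotr
        return x

    # Toy SPD matrix standing in for the Fisher information / Hessian-vector product.
    H = torch.tensor([[4.0, 1.0], [1.0, 3.0]])
    grad = torch.tensor([1.0, 2.0])
    step_dir = conjugate_gradient(lambda v: H @ v, grad,
                                  nsteps=config["conjugate_gradient_iters"])
    print(step_dir)  # approximately H^{-1} grad

Exposing the iteration count in `config.yaml` also removes the ambiguity that the renamed comment addresses: `learning_iters` controls the number of SGD/update iterations, while `conjugate_gradient_iters` controls the CG solver.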