diff --git a/domainlab/algos/trainers/fbopt_setpoint_ada.py b/domainlab/algos/trainers/fbopt_setpoint_ada.py
index 1361708d5..08fa06fef 100644
--- a/domainlab/algos/trainers/fbopt_setpoint_ada.py
+++ b/domainlab/algos/trainers/fbopt_setpoint_ada.py
@@ -51,14 +51,12 @@ def transition_to(self, state):
         self.state_updater = state
         self.state_updater.accept(self)
 
-    def update_setpoint_ma(self, target):
+    def update_setpoint_ma(self, list_target):
         """
         using moving average
         """
-        temp_ma = self.coeff_ma * torch.tensor(target)
-        temp_ma += (1 - self.coeff_ma) * torch.tensor(self.setpoint4R)
-        temp_ma = temp_ma.tolist()
-        self.setpoint4R = temp_ma
+        target_ma = [self.coeff_ma * a + (1 - self.coeff_ma) * b for a, b in zip(self.setpoint4R, list_target)]
+        self.setpoint4R = target_ma
 
     def observe(self, epo_reg_loss, epo_task_loss):
         """
@@ -69,7 +67,7 @@ def observe(self, epo_reg_loss, epo_task_loss):
         self.state_task_loss = epo_task_loss
         if self.state_updater.update_setpoint():
             logger = Logger.get_logger(logger_name='main_out_logger', loglevel="INFO")
-            self.setpoint4R = self.state_epo_reg_loss
+            self.update_setpoint_ma(self.state_epo_reg_loss)
             logger.info(f"!!!!!set point updated to {self.setpoint4R}!")