diff --git a/domainlab/algos/trainers/fbopt_mu_controller.py b/domainlab/algos/trainers/fbopt_mu_controller.py
index 3193ae225..2429a3bd6 100644
--- a/domainlab/algos/trainers/fbopt_mu_controller.py
+++ b/domainlab/algos/trainers/fbopt_mu_controller.py
@@ -252,3 +252,10 @@ def update_setpoint(self, epo_reg_loss, epo_task_loss):
         update setpoint
         """
         return self.set_point_controller.observe(epo_reg_loss, epo_task_loss)
+
+    def __call__(self, epoch):
+        """
+        return the current multipliers self.mmu; epoch is not used, it is
+        only accepted so this object can be passed as fun_scheduler
+        """
+        return self.mmu
diff --git a/domainlab/algos/trainers/train_fbopt_b.py b/domainlab/algos/trainers/train_fbopt_b.py
index a97875bc6..907990395 100644
--- a/domainlab/algos/trainers/train_fbopt_b.py
+++ b/domainlab/algos/trainers/train_fbopt_b.py
@@ -86,10 +86,12 @@ def before_tr(self):
         self.flag_setpoint_updated = False
         if self.aconf.force_feedforward:
             self.set_scheduler(scheduler=HyperSchedulerWarmup)
+            self.flag_update_hyper_per_epoch = True
+            self.hyper_scheduler.set_steps(total_steps=self.aconf.warmup)
         else:
             self.set_scheduler(scheduler=HyperSchedulerFeedback)
 
-        self.set_model_with_mu() # very small value
+        self.set_model_with_mu(0) # very small value
         if self.aconf.tr_with_init_mu:
             self.tr_with_init_mu()
 
@@ -113,12 +115,12 @@ def tr_with_init_mu(self):
         """
         super().tr_epoch(-1)
 
-    def set_model_with_mu(self):
+    def set_model_with_mu(self, epoch):
         """
         set model multipliers
         """
-        self._model.hyper_update(epoch=None, fun_scheduler=HyperSetter(self.hyper_scheduler.mmu))
+        self.model.hyper_update(epoch=epoch, fun_scheduler=self.hyper_scheduler)
 
     def tr_epoch(self, epoch, flag_info=False):
         """
         update multipliers only per epoch
@@ -129,7 +131,7 @@ def tr_epoch(self, epoch, flag_info=False):
             self.epo_loss_tr,
             self.list_str_multiplier_na,
             miter=epoch)
-        self.set_model_with_mu()
+        self.set_model_with_mu(epoch)
         if hasattr(self.model, "dict_multiplier"):
             logger = Logger.get_logger()
             logger.info(f"current multiplier: {self.model.dict_multiplier}")
diff --git a/tests/test_fbopt.py b/tests/test_fbopt.py
index 3e100acd8..3add87cf6 100644
--- a/tests/test_fbopt.py
+++ b/tests/test_fbopt.py
@@ -33,3 +33,10 @@ def test_forcesetpoint_fbopt():
     args = "--te_d=0 --tr_d 1 2 --task=mnistcolor10 --bs=16 --aname=jigen --trainer=fbopt --nname=conv_bn_pool_2 --epos=10 --es=0 --mu_init=0.00001 --coeff_ma_setpoint=0.5 --coeff_ma_output_state=0.99 --force_setpoint_change_once"
     utils_test_algo(args)
 
+
+def test_forcefeedforward_fbopt():
+    """
+    fbopt trainer with --force_feedforward
+    """
+    args = "--te_d=0 --tr_d 1 2 --task=mnistcolor10 --bs=16 --aname=jigen --trainer=fbopt --nname=conv_bn_pool_2 --epos=2000 --epos_min=100 --es=1 --force_feedforward"
+    utils_test_algo(args)